Change 8bit optimizer blocksize 2048->256; additional bf16 support #1365

Merged: 9 commits, Sep 20, 2024
6 changes: 5 additions & 1 deletion bitsandbytes/functional.py
@@ -52,6 +52,7 @@ def prod(iterable):
"lamb": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"ademamix": (
lib.cademamix32bit_grad_fp32,
@@ -96,10 +97,12 @@ def prod(iterable):
"momentum": (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
lib.cmomentum_8bit_blockwise_grad_bf16,
),
"rmsprop": (
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
lib.crmsprop_8bit_blockwise_grad_bf16,
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
@@ -109,6 +112,7 @@ def prod(iterable):
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
lib.cadagrad_8bit_blockwise_grad_bf16,
),
"ademamix": (
lib.cademamix_8bit_blockwise_grad_fp32,
@@ -398,7 +402,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0)

data.sort()
return Tensor(data)
return torch.tensor(data)


def create_quantile_map(A, total_bits=8):
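Note: the `create_dynamic_map` fix replaces a bare `Tensor(data)` constructor call with `torch.tensor(data)`, so the 8-bit code is built correctly from the Python list. A minimal round-trip sketch of how that code is consumed, assuming the usual `bitsandbytes.functional` API (exact signatures may vary by version):

```python
import torch
import bitsandbytes.functional as F

# The dynamic map is an 8-bit code: 256 float values in [-1, 1].
code = F.create_dynamic_map(signed=True)
assert isinstance(code, torch.Tensor) and code.numel() == 256

# Round-trip a tensor through blockwise quantization at the new blocksize.
A = torch.randn(4096, device="cuda")
q, state = F.quantize_blockwise(A, code=code, blocksize=256)
A_hat = F.dequantize_blockwise(q, state)
print((A - A_hat).abs().max())  # small, bounded by per-block absmax scaling
```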
2 changes: 1 addition & 1 deletion bitsandbytes/optim/ademamix.py
@@ -166,7 +166,7 @@ def init_state(self, group, p, gindex, pindex):
self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)

n = p.numel()
blocks = (n // 2048) + bool(n % 2048)
blocks = (n // 256) + bool(n % 256)

state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
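Note: the block count here is just a ceiling division by the new blocksize. A sketch of the resulting state-buffer shapes (the helper name is ours, not the library's):

```python
import torch

def ademamix_state_shapes(p: torch.Tensor, blocksize: int = 256):
    n = p.numel()
    blocks = (n // blocksize) + bool(n % blocksize)  # ceil(n / blocksize)
    # AdEMAMix packs the two first moments (m1, m2) into one buffer, hence (2, blocks).
    absmax1 = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
    absmax2 = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
    return absmax1, absmax2

a1, a2 = ademamix_state_shapes(torch.empty(1000))
assert a1.shape == (2, 4) and a2.shape == (4,)  # ceil(1000 / 256) == 4
```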
8 changes: 4 additions & 4 deletions bitsandbytes/optim/optimizer.py
@@ -477,8 +477,8 @@ def init_state(self, group, p, gindex, pindex):

if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0

state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
@@ -699,8 +699,8 @@ def init_state(self, group, p, gindex, pindex):

if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0

state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
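Note on the memory tradeoff of the smaller blocks, as back-of-envelope arithmetic (ours, not from the PR): absmax is one fp32 per block, so shrinking blocks from 2048 to 256 multiplies that overhead by 8, which is still small next to the 1 byte/param of the quantized state itself.

```python
# absmax overhead per parameter, in bytes, at each blocksize
for blocksize in (2048, 256):
    overhead_bytes_per_param = 4 / blocksize
    print(f"blocksize={blocksize}: {overhead_bytes_per_param:.4f} B/param "
          f"({overhead_bytes_per_param:.2%} of a 1-byte quantized state)")
```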
42 changes: 26 additions & 16 deletions csrc/kernels.cu
@@ -3829,27 +3829,33 @@ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -3950,6 +3956,8 @@ MAKE_optimizerStatic8bit2State(ADAM, float)

template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
// template __global__ void kPercentileClipping<float, 128, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
// template __global__ void kPercentileClipping<half, 128, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);

#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -4041,13 +4049,12 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048, 8)

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)

#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
@@ -4059,15 +4066,18 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
14 changes: 10 additions & 4 deletions csrc/ops.cu
@@ -191,10 +191,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
}
}

#define BLOCKSIZE_2STATE 2048
#define NUM_2STATE 8
#define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8
#define BLOCKSIZE_2STATE 256
#define NUM_2STATE 1
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
@@ -818,13 +818,16 @@ MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
@@ -861,13 +864,16 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
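Note: with `BLOCKSIZE_*STATE = 256` and `NUM_*STATE = 1`, each CUDA thread block runs `blocksize / num_per_thread = 256` threads and covers exactly one quantization block. A sketch of the launch arithmetic implied by these constants (our illustration, not library code):

```python
import math

def launch_config(n: int, blocksize: int = 256, num_per_thread: int = 1):
    threads_per_block = blocksize // num_per_thread
    grid_blocks = math.ceil(n / blocksize)  # matches ops.cu's rounded-up division
    return grid_blocks, threads_per_block

print(launch_config(4096 * 4096))  # (65536, 256) for a 4096x4096 weight
```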
14 changes: 10 additions & 4 deletions csrc/pythonInterface.cpp
@@ -103,19 +103,22 @@ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)


@@ -283,13 +286,16 @@ extern "C"

MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
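Note: these `MAKE_CBLOCKWISE8` lines are what produce the exported C symbols (`c<name>_8bit_blockwise_grad_<gbits>`) that functional.py's optimizer table resolves, which is why the bf16 entries there need matching bindings here. A sketch of the naming pattern:

```python
# Names the new bf16 bindings add, alongside the existing fp32/fp16 ones.
for name in ("momentum", "rmsprop", "adagrad"):
    for gbits in ("fp32", "fp16", "bf16"):
        print(f"c{name}_8bit_blockwise_grad_{gbits}")
```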
54 changes: 32 additions & 22 deletions tests/test_optim.py
@@ -74,10 +74,18 @@ def rm_path(path):
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["paged_ademamix_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)

str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
@@ -143,7 +151,7 @@ def rm_path(path):
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
@@ -164,6 +172,7 @@ def rm_path(path):
"ademamix",
"ademamix_scheduled",
"paged_ademamix",
"paged_ademamix_scheduled",
]


@@ -309,18 +318,15 @@ def test_global_config(dim1, dim2, gtype):
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6)

if gtype == torch.bfloat16 and optim_name not in [
"adam8bit_blockwise",
"lion8bit_blockwise",
"ademamix8bit_blockwise",
]:
if gtype == torch.bfloat16 and "blockwise" not in optim_name:
pytest.skip()

if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
blocksize = 256

torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -347,8 +353,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()

# since Lion can have pretty noisy updates where things lie at the boundary
# and AdEMAMix can diverge as well, allow up to 0.05% errors.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -392,11 +397,11 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015
assert relerr.mean() < 0.0020 # 0.0016
assert err.mean() <= 0.00017
assert relerr.mean() <= 0.0016
else:
assert err.mean() < 0.00016 # 0.00012
assert relerr.mean() < 0.0016 # 0.0012
assert err.mean() < 0.00006
assert relerr.mean() < 0.0006

errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
@@ -454,9 +459,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary
# and AdEMAMix can also be noisy, allow up to 0.05%.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))

# Lion can have pretty noisy updates where things lie at the boundary
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
@@ -560,15 +565,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
optimizer_names_benchmark = [
"adam8bit_blockwise",
"paged_adam8bit_blockwise",
"paged_adamw8bit_blockwise",
"ademamix8bit_blockwise",
"paged_ademamix8bit_blockwise",
"ademamix8bit_blockwise_scheduled",
"paged_ademamix8bit_blockwise_scheduled",
"lion8bit_blockwise",
"paged_lion8bit_blockwise",
"paged_ademamix8bit_blockwise",
]


@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
Expand All @@ -580,8 +589,9 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):

g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
if i == k // 5:
total_steps = 500
for i in range(total_steps):
if i == total_steps // 5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -591,8 +601,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize()
s = time.time() - t0
print("")
params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params)
params = (total_steps - total_steps // 5) * dim1 * dim2
print(optim_name, gtype, s, params, s / params)
# assert s < 3.9


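Note on the revised benchmark's metric, a sketch of what the printed `s / params` value means (arithmetic reproduced from the test above):

```python
# 500 steps on a 4096x4096 parameter; the first fifth is untimed burn-in.
total_steps, dim1, dim2 = 500, 4096, 4096
burn_in = total_steps // 5                       # 100 warm-up steps
params = (total_steps - burn_in) * dim1 * dim2   # parameter updates timed
print(params)                                    # 6710886400
# s / params is then wall-clock seconds per single parameter update.
```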