Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AdEMAMix optimizer #1360

Merged
merged 5 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def prod(iterable):
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
),
"ademamix": (
lib.cademamix32bit_grad_fp32,
lib.cademamix32bit_grad_fp16,
lib.cademamix32bit_grad_bf16,
),
}

str2optimizer8bit = {
Expand Down Expand Up @@ -105,6 +110,11 @@ def prod(iterable):
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
),
"ademamix": (
lib.cademamix_8bit_blockwise_grad_fp32,
lib.cademamix_8bit_blockwise_grad_fp16,
lib.cademamix_8bit_blockwise_grad_bf16,
),
}


Expand Down Expand Up @@ -1550,6 +1560,8 @@ def optimizer_update_32bit(
lr: float,
state2: Optional[torch.Tensor] = None,
beta2: float = 0.0,
beta3: float = 0.0,
alpha: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1585,6 +1597,10 @@ def optimizer_update_32bit(
Optimizer state 2.
beta2 : float
Optimizer beta2.
beta3 : float
Optimizer beta3.
alpha : float
Optimizer alpha.
gnorm_scale : float
The factor to rescale the gradient to the max clip value.
unorm_vec : torch.Tensor
Expand Down Expand Up @@ -1623,6 +1639,8 @@ def optimizer_update_32bit(
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
Expand Down Expand Up @@ -1775,6 +1793,8 @@ def optimizer_update_8bit_blockwise(
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
Expand Down Expand Up @@ -1815,6 +1835,8 @@ def optimizer_update_8bit_blockwise(
get_ptr(state2),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
Expand Down
1 change: 1 addition & 0 deletions bitsandbytes/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PagedAdamW8bit,
PagedAdamW32bit,
)
from .ademamix import AdEMAMix, AdEMAMix8bit, PagedAdEMAMix, PagedAdEMAMix8bit
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
Expand Down
254 changes: 254 additions & 0 deletions bitsandbytes/optim/ademamix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
from typing import Iterable, Literal, Optional, Tuple

import torch

import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State


class _ReferenceAdEMAMix(torch.optim.Optimizer):
"""
Reference: https://hf.co/papers/2409.03137
"""

def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
eps: float = 1e-8,
weight_decay: float = 1e-2, # default 0.0 or 1e-2?
t_beta3: Optional[int] = None,
t_alpha: Optional[int] = None,
):
defaults = dict(
lr=lr, betas=betas, alpha=alpha, eps=eps, weight_decay=weight_decay, t_beta3=t_beta3, t_alpha=t_alpha
)

super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
loss = None

if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if "step" in group:
group["step"] += 1
else:
group["step"] = 1

lr = group["lr"]
eps = group["eps"]
beta1, beta2, beta3_final = group["betas"]
alpha_final = group["alpha"]
weight_decay = group["weight_decay"]

for p in group["params"]:
if p.grad is None:
continue

grad = p.grad
state = self.state[p]

# State initialization
if len(state) == 0:
# For parity with bnb implementation we combine both fast
# and slow EMA stats into one stacked tensor.
state["m1_m2"] = p.new_zeros((2, *p.size()))
Comment on lines +63 to +65
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done for ease of compatibility with the existing test suite. In most other implementations we'll see two separate buffers here.

state["nu"] = torch.zeros_like(p) # second moment estimate

m1, m2, nu = state["m1_m2"][0], state["m1_m2"][1], state["nu"]

bias_correction1 = 1 - beta1 ** group["step"]

bias_correction2 = 1 - beta2 ** group["step"]

# TODO: Implement schedulers for alpha/beta3
beta3 = beta3_final
alpha = alpha_final

# Update the EMAs
m1.mul_(beta1).add_(grad, alpha=1 - beta1)
m2.mul_(beta3).add_(grad, alpha=1 - beta3)
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# Compute step
denom = (nu.sqrt() / (bias_correction2**0.5)).add(eps)
update = (m1.div(bias_correction1) + alpha * m2) / denom

# Add weight decay
update.add_(p, alpha=weight_decay)

# Apply update scaled by learning rate
p.add_(-lr * update)

return loss


class AdEMAMix(Optimizer2State):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
optim_bits: Literal[8, 32] = 32,
min_8bit_size: int = 4096,
is_paged: bool = False,
):
super().__init__(
"ademamix",
params=params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
optim_bits=optim_bits,
args=None,
min_8bit_size=min_8bit_size,
percentile_clipping=100,
block_wise=True,
is_paged=is_paged,
alpha=alpha,
)

@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
# In our AdEMAMix implementation, we use `state` to hold
# both the fast and slow EMAs. Here we override the base
# `Optimizer2State` to allocate a buffer twice as large.
# Additional consideration: we do not support block_wise=False,
# percentile clipping, or max_unorm.

config = self.get_config(gindex, pindex, group)

if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

if p.numel() < config["min_8bit_size"]:
dtype = torch.float32

state = self.state[p]
state["step"] = 0

if dtype == torch.uint8:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)

n = p.numel()
blocks = (n // 2048) + bool(n % 2048)

state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

state["state1"] = self._get_state_double_buffer(p, dtype=dtype)
state["state2"] = self.get_state_buffer(p, dtype=dtype)

def _get_state_double_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 0.5e5:
return torch.zeros((2, *p.size()), dtype=dtype, device=p.device)
else:
buff = F.get_paged(*(2, *p.size()), dtype=dtype, device=p.device)
F.fill(buff, 0)
self.page_mng.paged_tensors.append(buff)
return buff


class AdEMAMix8bit(AdEMAMix):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
is_paged: bool = False,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
optim_bits=8,
min_8bit_size=min_8bit_size,
is_paged=is_paged,
)


class PagedAdEMAMix8bit(AdEMAMix8bit):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
min_8bit_size=min_8bit_size,
is_paged=True,
)


class PagedAdEMAMix(AdEMAMix):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
optim_bits: Literal[8, 32] = 32,
min_8bit_size: int = 4096,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
optim_bits=optim_bits,
min_8bit_size=min_8bit_size,
is_paged=True,
)
Loading
Loading