Refactored simulated fp8 modules into research.nn.
TimDettmers committed Apr 12, 2023
1 parent e67bfcc commit dd562c2
Showing 15 changed files with 108 additions and 272 deletions.
5 files renamed without changes.
2 changes: 1 addition & 1 deletion bitsandbytes/nn/__init__.py
@@ -2,5 +2,5 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
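
The practical effect of this hunk for downstream code is an import-path change: the simulated FP8 layers are no longer exported from bitsandbytes.nn and instead live under bitsandbytes.research.nn (see the new __init__.py at the bottom of this diff). A minimal sketch of where things import from after this commit, assuming the package is importable as bnb (layer sizes are illustrative):

import bitsandbytes as bnb

# before this commit the FP8 layers were reachable as bnb.nn.LinearFP8Mixed etc.;
# after it, the simulated FP8 layers come from the research namespace
int8_layer = bnb.nn.Linear8bitLt(1024, 4096, has_fp16_weights=False)  # unchanged export
fp8_mixed  = bnb.research.nn.LinearFP8Mixed(1024, 4096)               # moved out of bnb.nn
fp8_global = bnb.research.nn.LinearFP8Global(1024, 4096)              # moved out of bnb.nn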
176 changes: 1 addition & 175 deletions bitsandbytes/nn/modules.py
@@ -297,7 +297,7 @@ def forward(self, x):
return out


class Linear8bitLtMixed(nn.Linear):
class SwitchBackLinearBnb(nn.Linear):
def __init__(
self,
input_features,
@@ -355,177 +355,3 @@ def forward(self, x):
del self.state.CxB

return out


class Linear8bitLtThresh(Linear8bitLt):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=6.0,
index=None,
):
super().__init__(
input_features,
output_features,
bias=bias,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=6.,
index=index
)

class LinearFP8(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break

def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

out = bnb.research.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias

return out

class LinearFP8Mixed(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break

def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias

return out

class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break

def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias

return out

class LinearInt8(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break

def forward(self, x: torch.Tensor):
if self.code is None:
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias

return out

# This is 4 bit version.
class LinearInt8Cast(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break


def forward(self, x: torch.Tensor):
if self.code is None:
self.code = bnb.functional.create_linear_map(True, 4).to(x.device)

out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
if self.bias is not None:
out += self.bias

return out


class LinearFP4(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break

def forward(self, x: torch.Tensor):
if self.fw_code is None:
#self.bw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)

out = bnb.matmul_fp4(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias

return out
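
The deleted classes above are thin wrappers with a common pattern: derive block sizes from the feature dimensions, build FP8 codebooks with create_fp8_map, and call the matching matmul from bnb.research. A minimal sketch of driving the surviving mixed variant directly, mirroring the deleted LinearFP8Mixed.forward (the shapes and CUDA device are assumptions):

import torch
import bitsandbytes as bnb

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)     # activations
w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)  # nn.Linear-style weight (out, in)

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)  # E4M3-style code for the forward pass
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)  # E5M2-style code for the backward pass

# bsz/bsz2 default to -1, so get_block_sizes derives them from the matrix shapes
out = bnb.research.matmul_fp8_mixed(x, w.t(), fw_code=fw_code, bw_code=bw_code)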
3 changes: 1 addition & 2 deletions bitsandbytes/research/__init__.py
@@ -1,6 +1,5 @@

from . import nn
from .autograd._functions import (
matmul_fp8,
switchback_bnb,
matmul_fp8_global,
matmul_fp8_mixed,
98 changes: 8 additions & 90 deletions bitsandbytes/research/autograd/_functions.py
@@ -16,88 +16,6 @@ def prod(iterable):

tensor = torch.Tensor

class MatMulFP8(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

@staticmethod
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B

B_shape = B.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

# 1. Dequantize
# 2. MatmulnN
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

cB, state = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state).to(B.dtype)

output = torch.matmul(fp8A, fp8B)

# output is half

# 3. Save state
ctx.fw_code = fw_code
ctx.bw_code = bw_code
ctx.bsz = bsz
ctx.bsz2 = bsz2
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

if any(ctx.needs_input_grad[:2]):
# NOTE: we send back A, and re-quant.
ctx.tensors = (A, fp8B)
else:
ctx.tensors = (None, None)

return output

@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors

grad_A, grad_B = None, None

cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)

cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

# not supported by PyTorch. TODO: create work-around
if req_gradA:
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

if req_gradB:
if len(A.shape) == 3:
At = A.transpose(2, 1).contiguous()
else:
At = A.transpose(1, 0).contiguous()
cA, state = F.quantize(At.float(), code=ctx.fw_code)
fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

return grad_A, grad_B, None, None, None, None, None

class MatMulFP8Mixed(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@@ -171,7 +89,10 @@ def backward(ctx, grad_output):
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

if req_gradB:
At = A.transpose(2, 1).contiguous()
if len(A.shape) == 3:
At = A.transpose(2, 1).contiguous()
else:
At = A.transpose(1, 0).contiguous()
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
# fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
@@ -252,7 +173,10 @@ def backward(ctx, grad_output):
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

if req_gradB:
At = A.transpose(2, 1).contiguous()
if len(A.shape) == 3:
At = A.transpose(2, 1).contiguous()
else:
At = A.transpose(1, 0).contiguous()
cA, state = F.quantize(At.float(), code=ctx.fw_code)
fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
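
Aside from deleting MatMulFP8, the functional change in this file is in the two hunks above: the backward of the mixed and global variants now branches on the rank of A, so 2D activations no longer hit the hard-coded transpose(2, 1). A hedged sketch that exercises both ranks through the re-exported layer, assuming it behaves like the deleted bnb.nn class (sizes and the CUDA requirement are illustrative):

import torch
import bitsandbytes as bnb

layer = bnb.research.nn.LinearFP8Mixed(256, 256).cuda().half()

for shape in [(8, 256), (4, 16, 256)]:  # 2D and 3D inputs
    x = torch.randn(*shape, device="cuda", dtype=torch.float16, requires_grad=True)
    layer(x).sum().backward()           # the weight-gradient path now handles both ranks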
@@ -465,11 +389,6 @@ def get_block_sizes(input_matrix, weight_matrix):

return bsz, bsz2


def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)

def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
@@ -478,7 +397,6 @@ def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)


def switchback_bnb(
A: tensor,
B: tensor,
1 change: 1 addition & 0 deletions bitsandbytes/research/nn/__init__.py
@@ -0,0 +1 @@
from .modules import LinearFP8Mixed, LinearFP8Global
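
Only LinearFP8Mixed and LinearFP8Global are re-exported by the new package __init__; the other simulated layers deleted from bitsandbytes/nn/modules.py are not listed here. A short usage sketch of the relocated LinearFP8Mixed, assuming it keeps the constructor signature shown in the deleted code (model dimensions are illustrative):

import torch
import torch.nn as nn
import bitsandbytes as bnb

# toy block with the simulated-FP8 projections swapped in for nn.Linear
mlp = nn.Sequential(
    bnb.research.nn.LinearFP8Mixed(512, 2048),
    nn.GELU(),
    bnb.research.nn.LinearFP8Mixed(2048, 512),
).cuda().half()

x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
y = mlp(x)  # forward pass runs the simulated FP8 matmuls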