diff --git a/.github/scripts/set_platform_tag.py b/.github/scripts/set_platform_tag.py
index ca561c880..c82077074 100644
--- a/.github/scripts/set_platform_tag.py
+++ b/.github/scripts/set_platform_tag.py
@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
     system = platform.system()
 
     if system == "Linux":
-        tag = (
-            "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
-        )
+        tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
     elif system == "Darwin":
         tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
     elif system == "Windows":
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c8ccfe8df..a859d05af 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.0
+    rev: v0.3.2
     hooks:
       - id: ruff
         args:
           - --fix
-      # - id: ruff-format  # TODO: enable when the time is right
+      - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0
     hooks:
diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py
index b23f63562..fd0dd7d58 100644
--- a/benchmarking/switchback/make_plot_with_jsonl.py
+++ b/benchmarking/switchback/make_plot_with_jsonl.py
@@ -1,13 +1,11 @@
-
 import matplotlib.gridspec as gridspec
 import matplotlib.pyplot as plt
 import pandas as pd
 
-cmap=plt.get_cmap('cool')
-
-if __name__ == '__main__':
+cmap = plt.get_cmap("cool")
 
-    fig = plt.figure(tight_layout=True, figsize=(12,3.5))
+if __name__ == "__main__":
+    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
     gs = gridspec.GridSpec(1, 2)
 
     dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
@@ -19,25 +17,28 @@
     ax = fig.add_subplot(gs[0, 0])
 
     # TODO: change this to what you want.
-    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
+    rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
     df = rdf[rdf.batch_size == batch_size_for_plot1]
 
     # first plot the time occupied by different operations
     for k, marker, ls, color, name in [
-        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
-        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
-
-        ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
-        ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
-        ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
-
-        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
-        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
-
-        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
-        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
-        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
-        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
+        ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
+        (
+            "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
+            "o",
+            "-",
+            "C4",
+            "SwitchBack int8 (sum of parts)",
+        ),
+        ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
+        ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
+        ("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
+        ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
+        ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
+        ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
+        ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
+        ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
+        ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
     ]:
         xs = []
         ys = []
@@ -47,40 +48,46 @@
             df_ = df_[df_.dim_out == embed_dim * 4]
             xs.append(embed_dim)
             y_ = 0
-            for k_ in k.split('+'):
+            for k_ in k.split("+"):
                 y_ += df_[k_].values[0]
             df_ = df[df.dim_in == embed_dim * 4]
             df_ = df_[df_.dim_out == embed_dim]
-            for k_ in k.split('+'):
+            for k_ in k.split("+"):
                 y_ += df_[k_].values[0]
             ys.append(y_ * 0.5)
 
+        ax.plot(
+            xs,
+            ys,
+            color=color,
+            label=name,
+            marker=marker,
+            markersize=5,
+            linestyle=ls,
+            linewidth=2 if "+" in k else 1.0,
+        )
 
-        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
-
-
-    ax.set_xlabel('dim', fontsize=13)
-    ax.set_ylabel('time (ms)', fontsize=13)
+    ax.set_xlabel("dim", fontsize=13)
+    ax.set_ylabel("time (ms)", fontsize=13)
 
     ax.grid()
 
-    ax.set_xscale('log')
+    ax.set_xscale("log")
     if logscale_plot1:
-        ax.set_yscale('log')
+        ax.set_yscale("log")
 
-    ax.tick_params(axis='x', labelsize=11)
-    ax.tick_params(axis='y', labelsize=11)
+    ax.tick_params(axis="x", labelsize=11)
+    ax.tick_params(axis="y", labelsize=11)
 
     ax.set_xticks(dims_to_xtick)
     ax.set_xticklabels(dims_to_xtick)
     ax.set_xticks([], minor=True)
 
-    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64,  1.), ncol=1, fontsize=10)
-    leg.get_texts()[0].set_fontweight('bold')
-    leg.get_texts()[1].set_fontweight('bold')
+    leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
+    leg.get_texts()[0].set_fontweight("bold")
+    leg.get_texts()[1].set_fontweight("bold")
     plt.subplots_adjust(left=0.1)
-    ax.set_title('  Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
-
+    ax.set_title("  Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)
 
     ax = fig.add_subplot(gs[0, 1])
 
@@ -88,10 +95,15 @@
     for j, batch_size in enumerate(batch_sizes_for_plot2):
         all_xs, all_ys = [], []
         for k, marker, ls, color, name in [
-            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
-            ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
+            ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
+            (
+                "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
+                "o",
+                "-",
+                "C4",
+                "SwitchBack int8 (total time)",
+            ),
         ]:
-
             xs, ys = [], []
             df = rdf[rdf.batch_size == batch_size]
             for embed_dim in dims_to_consider:
@@ -99,11 +111,11 @@
                 df_ = df_[df_.dim_out == embed_dim * 4]
                 xs.append(embed_dim)
                 y_ = 0
-                for k_ in k.split('+'):
+                for k_ in k.split("+"):
                     y_ += df_[k_].values[0]
                 df_ = df[df.dim_in == embed_dim * 4]
                 df_ = df_[df_.dim_out == embed_dim]
-                for k_ in k.split('+'):
+                for k_ in k.split("+"):
                     y_ += df_[k_].values[0]
                 ys.append(y_ * 0.5)
             all_xs.append(xs)
@@ -111,25 +123,29 @@
 
         color = cmap(j * 0.25)
         real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
-        markers = ['^', 'v', 'P', 'o']
-        ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
+        markers = ["^", "v", "P", "o"]
+        ax.plot(
+            all_xs[0],
+            real_ys,
+            color=color,
+            label=f"batch * sequence length = {batch_size}",
+            marker=markers[j],
+            markersize=5,
+        )
 
     ax.legend()
-    ax.set_xlabel('dim', fontsize=13)
-    ax.set_xscale('log')
+    ax.set_xlabel("dim", fontsize=13)
+    ax.set_xscale("log")
     ax.grid()
-    ax.set_ylabel(r'% speedup', fontsize=13)
+    ax.set_ylabel(r"% speedup", fontsize=13)
 
-
-    ax.tick_params(axis='x', labelsize=11)
-    ax.tick_params(axis='y', labelsize=11)
+    ax.tick_params(axis="x", labelsize=11)
+    ax.tick_params(axis="y", labelsize=11)
 
     ax.set_xticks(dims_to_xtick)
     ax.set_xticklabels(dims_to_xtick)
     ax.set_xticks([], minor=True)
 
-    ax.set_title('  Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
-
-
+    ax.set_title("  Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)
 
-    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
+    plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
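
The plotting loops above build each curve by splitting a "+"-joined key into per-operation columns of the benchmark jsonl and averaging the two (dim_in, dim_out) orientations. A stripped-down sketch of that aggregation, assuming the jsonl layout written by speed_benchmark.py (the batch size below is only an example value):

```python
# Minimal sketch of the "sum of parts" aggregation used by the plot script above.
# Assumes rows carry batch_size, dim_in, dim_out and one column per timed operation.
import pandas as pd


def summed_time_ms(df: pd.DataFrame, key: str, embed_dim: int) -> float:
    total = 0.0
    # the benchmark times both orientations: dim -> 4*dim and 4*dim -> dim
    for dim_in, dim_out in [(embed_dim, embed_dim * 4), (embed_dim * 4, embed_dim)]:
        row = df[(df.dim_in == dim_in) & (df.dim_out == dim_out)]
        for op in key.split("+"):
            total += row[op].values[0]
    return total * 0.5  # average over the two orientations, as in the script


rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
df = rdf[rdf.batch_size == 256 * 128]  # example; the script uses batch_size_for_plot1
print(summed_time_ms(df, "standard_gx+standard_gw+standard_fwd", 2048))
```
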
diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py
index c4f3cd4c6..eaba0e9cd 100644
--- a/benchmarking/switchback/speed_benchmark.py
+++ b/benchmarking/switchback/speed_benchmark.py
@@ -20,15 +20,15 @@
 
 # KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when the embedding dimension is too large.
 
-def get_time(k, fn, info_dict):
 
+def get_time(k, fn, info_dict):
     for _ in range(repeat // 2):
-       fn()
+        fn()
 
     torch.cuda.synchronize()
     start = time.time()
     for _ in range(repeat):
-       fn()
+        fn()
 
     torch.cuda.synchronize()
     end = time.time()
@@ -36,16 +36,15 @@ def get_time(k, fn, info_dict):
     print(f"time {k}: {ms:.3f} ms")
     info_dict[k] = ms
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     torch.manual_seed(0)
     wm = 4
     for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
         # note "batch_size" is actually "batch_size * embed_dim", which is why it's large
-        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
-
+        for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
             # switch switches dim_in and dim_out
             for switch in [False, True]:
-
                 # hparams
                 repeat = 64
                 batch_size = batch_size
@@ -73,35 +72,86 @@ def get_time(k, fn, info_dict):
                 state_w_rowwise = w.max(dim=1)[0]
                 state_w_global = w.max()
 
-                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
-
-                get_time('standard_fwd', lambda : x.matmul(w.t()), info)
-                get_time('standard_gw', lambda : g.t().matmul(x), info)
-                get_time('standard_gx', lambda : g.matmul(w), info)
-                get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
-                get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-                get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-                get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
-                get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
-                get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
-                get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
-                get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
-                get_time('w_quantize_global', lambda : quantize_global(w), info)
-                get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
-
-                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
-                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise']  + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
-                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
-
-                print('TOTAL STANDARD', time_standard)
-                print('TOTAL ROWWISE', time_rowwise)
-                print('TOTAL GLOBAL', time_global)
-
-                print('speedup', -100*(time_global - time_standard)/time_standard)
-
-                info['time_standard'] = time_standard
-                info['time_rowwise'] = time_rowwise
-                info['time_global'] = time_global
+                info = {
+                    "repeat": repeat,
+                    "batch_size": batch_size,
+                    "dim_out": dim_out,
+                    "dim_in": dim_in,
+                    "wm": wm,
+                    "switch": switch,
+                }
+
+                get_time("standard_fwd", lambda: x.matmul(w.t()), info)
+                get_time("standard_gw", lambda: g.t().matmul(x), info)
+                get_time("standard_gx", lambda: g.matmul(w), info)
+                get_time(
+                    "rowwise_fwd",
+                    lambda: int8_matmul_rowwise_dequantize(
+                        x_int8,
+                        w_int8.t(),
+                        state_x_rowwise,
+                        state_w_columnwise,
+                        None,
+                    ),
+                    info,
+                )
+                get_time(
+                    "rowwise_bwd",
+                    lambda: int8_matmul_rowwise_dequantize(
+                        g_int8,
+                        wt_int8.t(),
+                        state_x_rowwise,
+                        state_w_rowwise,
+                        None,
+                    ),
+                    info,
+                )
+                get_time(
+                    "global_fwd",
+                    lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
+                    info,
+                )
+                get_time(
+                    "global_bwd",
+                    lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
+                    info,
+                )
+                get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
+                get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
+                get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
+                get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
+                get_time("w_quantize_global", lambda: quantize_global(w), info)
+                get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)
+
+                time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
+                time_rowwise = (
+                    info["x_quantize_rowwise"]
+                    + info["g_quantize_rowwise"]
+                    + info["w_quantize_colwise_transpose"]
+                    + info["w_quantize_rowwise"]
+                    + info["standard_gw"]
+                    + info["rowwise_fwd"]
+                    + info["rowwise_bwd"]
+                )
+                time_global = (
+                    info["x_quantize_rowwise"]
+                    + info["g_quantize_rowwise"]
+                    + info["w_quantize_global"]
+                    + info["w_quantize_global_transpose"]
+                    + info["standard_gw"]
+                    + info["global_fwd"]
+                    + info["global_bwd"]
+                )
+
+                print("TOTAL STANDARD", time_standard)
+                print("TOTAL ROWWISE", time_rowwise)
+                print("TOTAL GLOBAL", time_global)
+
+                print("speedup", -100 * (time_global - time_standard) / time_standard)
+
+                info["time_standard"] = time_standard
+                info["time_rowwise"] = time_rowwise
+                info["time_global"] = time_global
 
                 info_json = json.dumps(info)
 
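
The reformatted get_time above follows the usual CUDA timing recipe: warm-up calls, torch.cuda.synchronize() before and after the timed loop, and an average over repeat runs. A self-contained sketch of the same pattern (illustrative; the matmul and the repeat count are stand-ins, not taken from the benchmark config):

```python
# Stand-alone sketch of the warm-up + synchronize timing pattern used by get_time().
import time

import torch


def time_cuda_ms(fn, repeat: int = 64) -> float:
    for _ in range(repeat // 2):
        fn()  # warm-up: amortize allocator, autotuning and cache effects
    torch.cuda.synchronize()  # drain queued work before starting the clock
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.time() - start) / repeat * 1000


if torch.cuda.is_available():
    x = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    print(f"fp16 matmul: {time_cuda_ms(lambda: x @ w.t()):.3f} ms")
```
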
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 6cbb6efd9..e9821cd36 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -14,16 +14,18 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
+
 # The inverse transformations for the colTuring and colAmpere formats were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
 
 
-
 """
     This class pools outlier dimensions across layers.
     This is particularly important for small models where outlier features
     are less systematic and occur with low frequency.
 """
+
+
 class GlobalOutlierPooler:
     _instance = None
 
@@ -83,6 +85,7 @@ def get_inverse_transform_indices(
             break  # if all indices fit in i bytes, stop early
     return permuted_tile_indices
 
+
 def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
     """
     Undo a tiled permutation such as turing or ampere layout
@@ -159,20 +162,12 @@ def backward(ctx, grad_output):
                     )
                     if not A.is_contiguous():
                         A = A.contiguous()
-                    qA, S2 = F.vectorwise_quant(
-                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
-                    )
+                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
                     igrad_B = F.igemm(qA.t(), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(
-                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
-                    )
+                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
                 else:
-                    qgrad_output, S1 = F.vectorwise_quant(
-                        grad_output, dim=dims, quant_type=quant_type
-                    )
-                    qA, S2 = F.vectorwise_quant(
-                        A, dim=dims, quant_type=quant_type
-                    )
+                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                     igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                     grad_B = F.vectorwise_mm_dequant(
                         igrad_B,
@@ -201,9 +196,7 @@ def backward(ctx, grad_output):
                 with torch.no_grad():
                     grad_A = torch.matmul(grad_output, B.permute(permute_dim))
             else:
-                qgrad_output, S1 = F.vectorwise_quant(
-                    grad_output, dim=dims, quant_type=quant_type
-                )
+                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                 qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                 igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                 grad_A = F.vectorwise_mm_dequant(
@@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool:
     if torch.cuda.get_device_capability(device=device) < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
-    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')  # https://en.wikipedia.org/wiki/GeForce_16_series
+    nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660")  # https://en.wikipedia.org/wiki/GeForce_16_series
     if any(model_name in device_name for model_name in nvidia16_models):
         return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
     return True
@@ -246,6 +239,7 @@ def get_tile_inds(format, device):
     with torch.no_grad():
         return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
 
+
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
@@ -510,7 +504,6 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
             else:
                 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
 
-
         # 1. Dequantize
         # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
@@ -532,7 +525,7 @@ def backward(ctx, grad_output):
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
 
-        req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
+        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
         A, B = ctx.tensors
 
         grad_A, grad_B, grad_bias = None, None, None
@@ -542,8 +535,9 @@ def backward(ctx, grad_output):
             grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
 
         # not supported by PyTorch. TODO: create work-around
-        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
+        if req_gradA:
+            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
 
@@ -554,7 +548,7 @@ def matmul(
     out: Optional[torch.Tensor] = None,
     state: Optional[MatmulLtState] = None,
     threshold=0.0,
-    bias=None
+    bias=None,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
@@ -562,11 +556,19 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
+def matmul_4bit(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    quant_state: F.QuantState,
+    out: Optional[torch.Tensor] = None,
+    bias=None,
+):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+            warn(
+                f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
+            )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
             out = F.gemv_4bit(A, B.t(), out, state=quant_state)
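
For readers skimming the reformatted matmul_4bit: the branching above selects the fused gemv_4bit kernel only for single-token inference inputs whose hidden dimension is a multiple of the quantization blocksize, and otherwise falls back to the dequantize-then-matmul autograd path. A condensed restatement of just that dispatch rule (illustrative, not library code):

```python
# Illustrative restatement of the matmul_4bit dispatch rule above (not library code).
import torch


def uses_fast_gemv(A: torch.Tensor, blocksize: int) -> bool:
    single_token = A.numel() == A.shape[-1]  # batch * sequence collapses to one row
    inference_only = not A.requires_grad     # the gemv path has no backward
    aligned = A.shape[-1] % blocksize == 0   # kernel needs a blocksize-aligned hidden dim
    return single_token and inference_only and aligned


A = torch.randn(1, 1, 4096)
print(uses_fast_gemv(A, blocksize=64))  # True  -> gemv_4bit fast path
print(uses_fast_gemv(A, blocksize=96))  # False -> MatMul4Bit fallback (with the warning above)
```
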
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 57ba71020..c8ae7358d 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
             "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
             "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
             "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
-            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
+            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
         )
 
     return PACKAGE_DIR / library_name
@@ -100,7 +100,7 @@ def get_native_library() -> BNBNativeLibrary:
 
     logger.warning(
         "The installed version of bitsandbytes was compiled without GPU support. "
-        "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
+        "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
     )
     return BNBNativeLibrary(dll)
 
@@ -120,5 +120,5 @@ def get_native_library() -> BNBNativeLibrary:
 Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
 to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
 and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
-"""
+""",
         )
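
The warnings reformatted above describe the BNB_CUDA_VERSION override and the matching LD_LIBRARY_PATH requirement. A hedged sketch of what that looks like from the user's side (the binary naming below is an assumption for illustration, not taken from this diff):

```python
# Illustrative check of the BNB_CUDA_VERSION override described in the warning above.
# The exact binary naming is an assumption; only the environment variables come from the text.
import os

override = os.environ.get("BNB_CUDA_VERSION", "")
if override:
    print(f"BNB_CUDA_VERSION={override}: bitsandbytes will try to load a binary built for "
          f"CUDA {override} (presumably something like libbitsandbytes_cuda{override}.so).")
    print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "<unset>"),
          "; make sure the matching libcudart is on it.")
else:
    print("No override set; the CUDA version detected from PyTorch is used.")
```
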
diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py
index d65f80d8b..f993dff7e 100644
--- a/bitsandbytes/diagnostics/cuda.py
+++ b/bitsandbytes/diagnostics/cuda.py
@@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
 
         The CUDA version for the compile might depend on your conda install, if using conda.
         Inspect CUDA version via `conda list | grep cuda`.
-        """
+        """,
         )
 
     cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
@@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """
             WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
             You will only be able to use 8-bit optimizers and quantization routines!
-            """
+            """,
         )
 
     print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
@@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None:
 
             In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
             export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
-            """
+            """,
         )
         for pth in cudart_paths:
             print(f"* Found CUDA runtime at: {pth}")
diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py
index 7a88bca26..1ce096f69 100644
--- a/bitsandbytes/diagnostics/main.py
+++ b/bitsandbytes/diagnostics/main.py
@@ -25,7 +25,7 @@ def sanity_check():
             See the documentation for more details if needed.
 
             Trying a simple check anyway, but this will likely fail...
-            """
+            """,
         )
 
     from bitsandbytes.optim import Adam
@@ -71,7 +71,7 @@ def main():
         print(
             f"WARNING: {__package__} is currently running as CPU-only!\n"
             "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
-            f"If you think that this is so erroneously,\nplease report an issue!"
+            f"If you think that this is so erroneously,\nplease report an issue!",
         )
     except Exception:
         traceback.print_exc()
@@ -80,6 +80,6 @@ def main():
         Above we output some debug information.
         Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
         WARNING: Please be sure to sanitize sensitive info from the output before posting it.
-        """
+        """,
     )
     sys.exit(1)
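
sanity_check above imports bitsandbytes.optim.Adam and exercises it once. A minimal smoke test in the same spirit, as a sketch (the exact check in main.py is not shown in this hunk; this assumes a CUDA device is available):

```python
# Minimal smoke test in the spirit of sanity_check(); assumes a CUDA device.
import torch

from bitsandbytes.optim import Adam

p = torch.nn.Parameter(torch.rand(16, 16, device="cuda"))
opt = Adam([p], lr=1e-3)

loss = (p ** 2).sum()
loss.backward()
opt.step()       # one optimizer step should run without raising
opt.zero_grad()
print("bitsandbytes Adam step OK")
```
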
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 61d0d83b2..8fa8f2f60 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -21,6 +21,7 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
+
 name2qmap = {}
 
 if lib and lib.compiled_with_cuda:
@@ -127,7 +128,6 @@ def prefetch_all(self, to_cpu=False):
             prefetch_tensor(t, to_cpu)
 
 
-
 class CUBLAS_Context:
     _instance = None
 
@@ -169,6 +169,7 @@ def get_instance(cls):
             cls._instance.initialize()
         return cls._instance
 
+
 dtype2bytes = {}
 dtype2bytes[torch.float32] = 4
 dtype2bytes[torch.float16] = 2
@@ -176,10 +177,11 @@ def get_instance(cls):
 dtype2bytes[torch.uint8] = 1
 dtype2bytes[torch.int8] = 1
 
-FIRST_CUDA_DEVICE = torch.device('cuda', index=0)
+FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
+
 
 def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
-    num_bytes = dtype2bytes[dtype]*prod(shape)
+    num_bytes = dtype2bytes[dtype] * prod(shape)
     cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
     c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -188,31 +190,35 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     out.page_deviceid = device.index
     return out
 
+
 def prefetch_tensor(A, to_cpu=False):
-    assert A.is_paged, 'Only paged tensors can be prefetched!'
+    assert A.is_paged, "Only paged tensors can be prefetched!"
     if to_cpu:
         deviceid = -1
     else:
         deviceid = A.page_deviceid
 
-    num_bytes = dtype2bytes[A.dtype]*A.numel()
+    num_bytes = dtype2bytes[A.dtype] * A.numel()
     lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
 
+
 def elementwise_func(func_name, A, B, value, prefetch=True):
     func = None
     if A.dtype == torch.float32:
-        func = getattr(lib, f'c{func_name}_fp32', None)
+        func = getattr(lib, f"c{func_name}_fp32", None)
         cvalue = ct.c_float(value)
     elif A.dtype == torch.uint8:
-        func = getattr(lib, f'c{func_name}_uint8', None)
+        func = getattr(lib, f"c{func_name}_uint8", None)
         cvalue = ct.c_uint8(value)
 
-    if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')
+    if func is None:
+        raise NotImplementedError(f"Function not implemented: {func_name}")
 
-    is_managed = getattr(A, 'is_managed', False)
+    is_managed = getattr(A, "is_managed", False)
     if is_managed and prefetch:
         prefetch_tensor(A)
-        if B is not None: prefetch_tensor(B)
+        if B is not None:
+            prefetch_tensor(B)
 
     func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
     if A.is_paged or B.is_paged:
@@ -222,28 +228,36 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
         # operation occurred. So we synchronize.
         torch.cuda.synchronize()
 
-def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
-def arange(A, device=None): elementwise_func('arange', A, None, 0)
-def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)
+
+def fill(A, value, device=None, prefetch=True):
+    elementwise_func("fill", A, None, value)
+
+
+def arange(A, device=None):
+    elementwise_func("arange", A, None, 0)
+
+
+def _mul(A, B, device=None):
+    elementwise_func("_mul", A, B, 0)
 
 
 def create_linear_map(signed=True, total_bits=8, add_zero=True):
-    sign = (-1.0 if signed else 0.0)
+    sign = -1.0 if signed else 0.0
     total_values = 2**total_bits
     if add_zero or total_bits < 8:
         # add a zero
         # since we simulate fewer bits by having zeros in the data type,
         # we need to center the quantization around zero and as such lose
         # a single value
-        total_values = (2**total_bits if not signed else 2**total_bits-1)
+        total_values = 2**total_bits if not signed else 2**total_bits - 1
 
     values = torch.linspace(sign, 1.0, total_values)
     gap = 256 - values.numel()
     if gap == 0:
         return values
     else:
-        l = values.numel()//2  # noqa: E741
-        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
+        l = values.numel() // 2  # noqa: E741
+        return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist())
 
 
 def create_normal_map(offset=0.9677083, use_extra_value=True):
@@ -251,18 +265,17 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
         from scipy.stats import norm
     except ImportError as ie:
         raise ImportError(
-            "Scipy is required for `create_normal_map`. "
-            "Install `bitsandbytes` with the `[test]` extra."
+            "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.",
         ) from ie
 
     if use_extra_value:
         # one more positive value, this is an asymmetric type
         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
-        v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
+        v2 = [0] * (256 - 15)  ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
-        v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
+        v2 = [0] * (256 - 14)  ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
 
     v = v1 + v2 + v3
@@ -275,38 +288,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
 
     return values
 
+
 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
     p = precision_bits
     has_sign = 1 if signed else 0
-    assert e+p == total_bits-has_sign
+    assert e + p == total_bits - has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     pvalues = []
-    for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)):
+    for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
         evalues.append(2**val)
 
-
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
-    #for ev in evalues:
-    bias = 2**(exponent_bits-1)
-    for evalue in range(2**(exponent_bits)):
+    # for ev in evalues:
+    bias = 2 ** (exponent_bits - 1)
+    for evalue in range(2 ** (exponent_bits)):
         for bit_pattern in lst:
-            value = (1 if evalue != 0 else 0)
+            value = 1 if evalue != 0 else 0
             for i, pval in enumerate(list(bit_pattern)):
-                value += pval*(2**-(i+1))
+                value += pval * (2 ** -(i + 1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias)
+                value = value * 2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-1)
+                value = value * 2 ** -(evalue - bias - 1)
             values.append(value)
             if signed:
                 values.append(-value)
 
-
     assert len(values) == 2**total_bits
     values.sort()
     if total_bits < 8:
@@ -320,7 +332,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     return code
 
 
-
 def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantization map.
@@ -345,7 +356,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     non_sign_bits = total_bits - (1 if signed else 1)
     additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     for i in range(max_exponent_bits):
-        fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)
+        fraction_items = int(
+            2 ** (i + non_sign_bits - max_exponent_bits) + 1
+            if signed
+            else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
@@ -371,8 +386,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.sort()
     return Tensor(data)
 
+
 def create_quantile_map(A, total_bits=8):
-    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
+    q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
     q = q.tolist()
     q.append(0)
 
@@ -383,11 +399,13 @@ def create_quantile_map(A, total_bits=8):
     q.sort()
 
     q = Tensor(q)
-    q = q/q.abs().max()
+    q = q / q.abs().max()
     return q
 
+
 def get_special_format_str():
-    if not torch.cuda.is_available(): return 'col_turing'
+    if not torch.cuda.is_available():
+        return "col_turing"
     major, _minor = torch.cuda.get_device_capability()
     if major <= 7:
         return "col_turing"
@@ -396,20 +414,24 @@ def get_special_format_str():
     return "col_turing"
 
 
-
 def is_on_gpu(tensors):
     on_gpu = True
     gpu_ids = set()
     for t in tensors:
-        if t is None: continue # NULL pointers are fine
-        is_paged = getattr(t, 'is_paged', False)
-        on_gpu &= (t.device.type == 'cuda' or is_paged)
+        if t is None:
+            continue  # NULL pointers are fine
+        is_paged = getattr(t, "is_paged", False)
+        on_gpu &= t.device.type == "cuda" or is_paged
         if not is_paged:
             gpu_ids.add(t.device.index)
     if not on_gpu:
-        raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
+        raise TypeError(
+            f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
+        )
     if len(gpu_ids) > 1:
-        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
+        raise TypeError(
+            f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
+        )
     return on_gpu
 
 
@@ -447,15 +469,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
     if not hasattr(lib, name):
         print(name)
         raise ValueError(
-            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
+            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
         )
     else:
         return getattr(lib, name)
 
 
-def get_transform_buffer(
-    shape, dtype, device, to_order, from_order="row", transpose=False
-):
+def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
     # init_func = torch.empty
     init_func = torch.zeros
     dims = len(shape)
@@ -508,9 +528,7 @@ def nvidia_transform(
     else:
         from_order = state[1]
     if out is None:
-        out, new_state = get_transform_buffer(
-            state[0], A.dtype, A.device, to_order, state[1]
-        )
+        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
     else:
         new_state = (state[1], to_order)
     func = get_transform_func(A.dtype, from_order, to_order, transpose)
@@ -534,8 +552,13 @@ def nvidia_transform(
     return out, new_state
 
 
-def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
-    '''
+def estimate_quantiles(
+    A: Tensor,
+    out: Optional[torch.Tensor] = None,
+    offset: float = 1 / 512,
+    num_quantiles=256,
+) -> Tensor:
+    """
     Estimates 256 equidistant quantiles on the input tensor eCDF.
 
     Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -562,14 +585,21 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
-    '''
-    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
-    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
-    if num_quantiles < 256 and offset == 1/(512):
+    """
+    if A.numel() < 256:
+        raise NotImplementedError(
+            f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
+        )
+    if num_quantiles > 256:
+        raise NotImplementedError(
+            f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
+        )
+    if num_quantiles < 256 and offset == 1 / (512):
         # override default arguments
-        offset = 1/(2*num_quantiles)
+        offset = 1 / (2 * num_quantiles)
 
-    if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+    if out is None:
+        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
     device = pre_call(A.device)
     if A.dtype == torch.float32:
@@ -581,7 +611,7 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
     post_call(device)
 
     if num_quantiles < 256:
-        step = round(256/num_quantiles)
+        step = round(256 / num_quantiles)
         idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
         out = out[idx]
 
@@ -590,12 +620,35 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
 
 class QuantState:
     """container for quantization state components to work with Params4bit and similar classes"""
-    valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
-    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
-                     'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
 
-    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
+    valid_quant_types = ("fp4", "nf4")
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = [
+        "absmax",
+        "quant_map",
+        "nested_absmax",
+        "nested_quant_map",
+        "quant_state",
+        "quant_type",
+        "blocksize",
+        "dtype",
+        "shape",
+        "nested_blocksize",
+        "nested_dtype",
+        "nested_offset",
+    ]
+
+    def __init__(
+        self,
+        absmax,
+        shape=None,
+        code=None,
+        blocksize=None,
+        quant_type=None,
+        dtype=None,
+        offset=None,
+        state2=None,
+    ):
         self.absmax = absmax
         self.shape = shape
         self.code = code
@@ -614,13 +667,20 @@ def __get_item__(self, idx):
         state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
         """
         if self.nested:
-            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type]
+            list_repr = [
+                self.absmax,
+                self.shape,
+                self.dtype,
+                self.blocksize,
+                [self.offset, self.state2],
+                self.quant_type,
+            ]
         else:
             list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
         return list_repr[idx]
 
     @classmethod
-    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
+    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
         """
         unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
@@ -632,37 +692,39 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
 
         # unpacking tensor with non-tensor components
         qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        if not len(qs_key) and 'quant_type' not in qs_dict:
+        if not len(qs_key) and "quant_type" not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
         elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
-            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")
+            raise ValueError(
+                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
+            )
 
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             first_qs_key = qs_key[0]
             qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
 
-        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+        qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes
         assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
 
-        if 'nested_absmax' in qs_dict:
-            offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
+        if "nested_absmax" in qs_dict:
+            offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
             state2 = cls(
-                absmax=qs_dict['nested_absmax'].to(device),
-                blocksize=qs_dict['nested_blocksize'],
-                code=qs_dict['nested_quant_map'].to(device),
-                dtype=getattr(torch, qs_dict['nested_dtype']),
+                absmax=qs_dict["nested_absmax"].to(device),
+                blocksize=qs_dict["nested_blocksize"],
+                code=qs_dict["nested_quant_map"].to(device),
+                dtype=getattr(torch, qs_dict["nested_dtype"]),
             )
         else:
             offset, state2 = None, None
 
         quant_state = cls(
-            quant_type=qs_dict['quant_type'],
-            absmax=qs_dict['absmax'].to(device),
-            blocksize=qs_dict['blocksize'],
-            code=qs_dict['quant_map'].to(device),
-            dtype=getattr(torch, qs_dict['dtype']),
-            shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None,
+            quant_type=qs_dict["quant_type"],
+            absmax=qs_dict["absmax"].to(device),
+            blocksize=qs_dict["blocksize"],
+            code=qs_dict["quant_map"].to(device),
+            dtype=getattr(torch, qs_dict["dtype"]),
+            shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
             offset=offset,
             state2=state2,
         )
@@ -674,21 +736,23 @@ def as_dict(self, packed=False):
         param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
         """
         qs_dict = {
-            'quant_type': self.quant_type,
-            'absmax': self.absmax,
-            'blocksize': self.blocksize,
-            'quant_map': self.code,
-            'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape),
+            "quant_type": self.quant_type,
+            "absmax": self.absmax,
+            "blocksize": self.blocksize,
+            "quant_map": self.code,
+            "dtype": str(self.dtype).strip("torch."),
+            "shape": tuple(self.shape),
         }
         if self.nested:
-            qs_dict.update({
-                'nested_absmax': self.state2.absmax,
-                'nested_blocksize': self.state2.blocksize,
-                'nested_quant_map': self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
-                'nested_dtype': str(self.state2.dtype).strip('torch.'),
-                'nested_offset': self.offset.item(),
-            })
+            qs_dict.update(
+                {
+                    "nested_absmax": self.state2.absmax,
+                    "nested_blocksize": self.state2.blocksize,
+                    "nested_quant_map": self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
+                    "nested_dtype": str(self.state2.dtype).strip("torch."),
+                    "nested_offset": self.offset.item(),
+                },
+            )
         if not packed:
             return qs_dict
 
@@ -711,14 +775,22 @@ def __eq__(self, other):
             return False
 
         return (
-            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
-            self.shape == other.shape and
-            torch.allclose(self.code, other.code, atol=1e-6) and
-            self.dtype == other.dtype and
-            self.blocksize == other.blocksize and
-            self.quant_type == other.quant_type and
-            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
-            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
+            torch.allclose(self.absmax, other.absmax, atol=1e-6)
+            and self.shape == other.shape
+            and torch.allclose(self.code, other.code, atol=1e-6)
+            and self.dtype == other.dtype
+            and self.blocksize == other.blocksize
+            and self.quant_type == other.quant_type
+            and (
+                self.offset == other.offset
+                if self.offset is not None and other.offset is not None
+                else self.offset is other.offset
+            )
+            and (
+                self.state2 == other.state2
+                if self.state2 is not None and other.state2 is not None
+                else self.state2 is other.state2
+            )
         )
 
 
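
QuantState above is the container that quantize_blockwise returns alongside the quantized tensor and that dequantize_blockwise consumes again; both functions are reformatted further down in this file. A round-trip sketch, assuming a CUDA build of bitsandbytes (sizes and the blocksize are illustrative):

```python
# Round-trip sketch for the blockwise quantization API touched in this file.
import torch

from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

x = torch.randn(4096, device="cuda", dtype=torch.float32)
q, state = quantize_blockwise(x, blocksize=256)  # uint8 codes plus a QuantState
x_hat = dequantize_blockwise(q, state)           # reconstruct fp32 from codes + state

print(q.dtype, x_hat.dtype)              # torch.uint8 torch.float32
print((x - x_hat).abs().max().item())    # small blockwise quantization error
```
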
@@ -756,7 +828,6 @@ def quantize_blockwise(
         The quantization state to undo the quantization.
     """
 
-
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -771,31 +842,66 @@ def quantize_blockwise(
     if out is None:
         out = torch.zeros_like(A, dtype=torch.uint8)
 
-    if A.device.type != 'cpu':
+    if A.device.type != "cpu":
         assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
         cblocksize = ct.c_int32(blocksize)
         prev_device = pre_call(A.device)
         code = code.to(A.device)
         is_on_gpu([code, A, out, absmax])
         if A.dtype == torch.float32:
-            lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
+            lib.cquantize_blockwise_fp32(
+                get_ptr(code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                cblocksize,
+                ct.c_int(A.numel()),
+            )
         elif A.dtype == torch.float16:
-            lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
+            lib.cquantize_blockwise_fp16(
+                get_ptr(code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                cblocksize,
+                ct.c_int(A.numel()),
+            )
         elif A.dtype == torch.bfloat16:
-            lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
+            lib.cquantize_blockwise_bf16(
+                get_ptr(code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                cblocksize,
+                ct.c_int(A.numel()),
+            )
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
         # cpu
         code = code.cpu()
-        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+        lib.cquantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
 
     if nested:
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
+        quant_state = QuantState(
+            absmax=qabsmax,
+            code=code,
+            blocksize=blocksize,
+            dtype=A.dtype,
+            offset=offset,
+            state2=state2,
+        )
     else:
         quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
 
@@ -809,7 +915,7 @@ def dequantize_blockwise(
     code: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 4096,
-    nested=False
+    nested=False,
 ) -> Tensor:
     """
     Dequantizes blockwise quantized values.
@@ -843,43 +949,76 @@ def dequantize_blockwise(
         code = name2qmap["dynamic"]
 
     if quant_state is None:
-       quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
 
     absmax = quant_state.absmax
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32: absmax = absmax.float()
+        if absmax.dtype != torch.float32:
+            absmax = absmax.float()
 
     if out is None:
         out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
 
-    if A.device.type != 'cpu':
+    if A.device.type != "cpu":
         device = pre_call(A.device)
         code = quant_state.code.to(A.device)
         if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-            raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+            raise ValueError(
+                f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
+            )
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(
+                get_ptr(quant_state.code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(A.numel()),
+            )
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(
+                get_ptr(quant_state.code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(A.numel()),
+            )
         elif out.dtype == torch.bfloat16:
-            lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_bf16(
+                get_ptr(quant_state.code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(A.numel()),
+            )
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
         code = quant_state.code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))
+        lib.cdequantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(quant_state.absmax),
+            get_ptr(out),
+            ct.c_longlong(quant_state.blocksize),
+            ct.c_longlong(A.numel()),
+        )
 
     return out
 
+
 def get_4bit_type(typename, device=None, blocksize=64):
-    if device is None: device = 'cuda'
+    if device is None:
+        device = "cuda"
     data = None
-    if typename == 'nf4':
-        ''' Implements the NF4 data type.
+    if typename == "nf4":
+        """ Implements the NF4 data type.
 
             Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
             is normalized into the range [-1, 1].
@@ -888,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
 
             Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
             the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
-        '''
-        data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
-                -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
-                0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
-                0.7229568362236023, 1.0]
-    elif typename == 'fp4':
+        """
+        data = [
+            -1.0,
+            -0.6961928009986877,
+            -0.5250730514526367,
+            -0.39491748809814453,
+            -0.28444138169288635,
+            -0.18477343022823334,
+            -0.09105003625154495,
+            0.0,
+            0.07958029955625534,
+            0.16093020141124725,
+            0.24611230194568634,
+            0.33791524171829224,
+            0.44070982933044434,
+            0.5626170039176941,
+            0.7229568362236023,
+            1.0,
+        ]
+    elif typename == "fp4":
         # 0b000 = 0
         # 0b001 = 0.0625
         # 0b010 = 8
@@ -904,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64):
         # 0b111 = 3
         # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
         data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
-    elif typename == 'int4':
+    elif typename == "int4":
         data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
-    elif typename == 'af4':
+    elif typename == "af4":
         # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
         # https://arxiv.org/abs/2306.06965
         if blocksize == 64:
-            data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
-                    -0.04934812,  0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
-                    0.42563882,  0.55496234,  0.72424863,  1.][::-1]
+            data = [
+                -1.0,
+                -0.69441008,
+                -0.51243739,
+                -0.3736951,
+                -0.25607552,
+                -0.14982478,
+                -0.04934812,
+                0.0,
+                0.04273164,
+                0.12934483,
+                0.21961274,
+                0.31675666,
+                0.42563882,
+                0.55496234,
+                0.72424863,
+                1.0,
+            ][::-1]
         else:
-            raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.')
+            raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.")
 
     if data is None:
-        raise NotImplementedError(f'Typename {typename} not supported')
+        raise NotImplementedError(f"Typename {typename} not supported")
 
     data = Tensor(data)
     data /= data.abs().max()
@@ -926,11 +1094,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
 
-def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
+def quantize_fp4(
+    A: Tensor,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=64,
+    compress_statistics=False,
+    quant_storage=torch.uint8,
+):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
 
-def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
+
+def quantize_nf4(
+    A: Tensor,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=64,
+    compress_statistics=False,
+    quant_storage=torch.uint8,
+):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
 
 
 def quantize_4bit(
@@ -939,7 +1122,7 @@ def quantize_4bit(
     out: Optional[torch.Tensor] = None,
     blocksize=64,
     compress_statistics=False,
-    quant_type='fp4',
+    quant_type="fp4",
     quant_storage=torch.uint8,
 ) -> Tuple[Tensor, QuantState]:
     """
@@ -967,10 +1150,10 @@ def quantize_4bit(
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
-    if A.device.type != 'cuda':
-        raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
-    if quant_type not in ['fp4', 'nf4']:
-        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
+    if A.device.type != "cuda":
+        raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
+    if quant_type not in ["fp4", "nf4"]:
+        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
 
     n = A.numel()
     input_shape = A.shape
@@ -980,10 +1163,9 @@ def quantize_4bit(
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
 
-
     if out is None:
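+        # Each quant_storage element packs 2 * dtype2bytes[quant_storage] 4-bit codes, hence the packed output length of (n + 1) // mod.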
         mod = dtype2bytes[quant_storage] * 2
-        out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)
+        out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
 
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
@@ -991,20 +1173,62 @@ def quantize_4bit(
     is_on_gpu([A, out, absmax])
 
     if A.dtype == torch.float32:
-        if quant_type == 'fp4':
-            lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+        if quant_type == "fp4":
+            lib.cquantize_blockwise_fp32_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+            lib.cquantize_blockwise_fp32_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
     elif A.dtype == torch.float16:
-        if quant_type == 'fp4':
-            lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+        if quant_type == "fp4":
+            lib.cquantize_blockwise_fp16_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+            lib.cquantize_blockwise_fp16_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
     elif A.dtype == torch.bfloat16:
-        if quant_type == 'fp4':
-            lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+        if quant_type == "fp4":
+            lib.cquantize_blockwise_bf16_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+            lib.cquantize_blockwise_bf16_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int32(blocksize),
+                ct.c_int(n),
+            )
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
@@ -1016,19 +1240,57 @@ def quantize_4bit(
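+        # compress_statistics: the absmax statistics are centered on their mean and blockwise-quantized again (nested statistics) to shrink the stored quantization state.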
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
+        state = QuantState(
+            absmax=qabsmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+            offset=offset,
+            state2=state2,
+        )
     else:
-        state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, )
+        state = QuantState(
+            absmax=absmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+        )
 
     return out, state
 
-def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
-    return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
 
-def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
-    return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
+def dequantize_fp4(
+    A: Tensor,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize: int = 64,
+) -> Tensor:
+    return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
-def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+
+def dequantize_nf4(
+    A: Tensor,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize: int = 64,
+) -> Tensor:
+    return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
+
+
+def dequantize_4bit(
+    A: Tensor,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize: int = 64,
+    quant_type="fp4",
+) -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
 
@@ -1056,23 +1318,31 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
         Dequantized tensor.
     """
     if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-        raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
-    if quant_type not in ['fp4', 'nf4']:
-        raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
+        raise ValueError(
+            f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
+        )
+    if quant_type not in ["fp4", "nf4"]:
+        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
 
     if quant_state is None:
         assert absmax is not None and out is not None
 
-        quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
+        quant_state = QuantState(
+            absmax=absmax,
+            shape=out.shape,
+            dtype=out.dtype,
+            blocksize=blocksize,
+            quant_type=quant_type,
+        )
 
     else:
         absmax = quant_state.absmax
 
-
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32: absmax = absmax.float()
+        if absmax.dtype != torch.float32:
+            absmax = absmax.float()
 
     if out is None:
         out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
@@ -1082,27 +1352,71 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
-        if quant_state.quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+        if quant_state.quant_type == "fp4":
+            lib.cdequantize_blockwise_fp32_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp32_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
     elif out.dtype == torch.float16:
-        if quant_state.quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+        if quant_state.quant_type == "fp4":
+            lib.cdequantize_blockwise_fp16_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp16_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
     elif out.dtype == torch.bfloat16:
-        if quant_state.quant_type == 'fp4':
-            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+        if quant_state.quant_type == "fp4":
+            lib.cdequantize_blockwise_bf16_fp4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
         else:
-            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_bf16_nf4(
+                get_ptr(None),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_int(quant_state.blocksize),
+                ct.c_int(n),
+            )
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
 
-    is_transposed = (True if A.shape[0] == 1 else False)
-    if is_transposed: return out.t()
-    else: return out
+    is_transposed = A.shape[0] == 1
+    if is_transposed:
+        return out.t()
+    else:
+        return out
 
 
 def quantize(
@@ -1117,7 +1431,8 @@ def quantize(
         code = code.to(A.device)
 
     absmax = torch.abs(A).max()
-    if absmax.dtype != torch.float32: absmax = absmax.float()
+    if absmax.dtype != torch.float32:
+        absmax = absmax.float()
     inp = A / absmax
     out = quantize_no_absmax(inp, code, out)
     return out, (absmax, code)
@@ -1144,7 +1459,7 @@ def dequantize(
 
 
 def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
-    '''
+    """
     Quantizes input tensor to 8-bit.
 
     Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -1163,9 +1478,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
     -------
     torch.Tensor:
         Quantized 8-bit tensor.
-    '''
+    """
     prev_device = pre_call(A.device)
-    if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+    if out is None:
+        out = torch.zeros_like(A, dtype=torch.uint8)
     is_on_gpu([A, out])
     lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     post_call(prev_device)
@@ -1173,7 +1489,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
 
 
 def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
-    '''
+    """
     Dequantizes the 8-bit tensor to 32-bit.
 
     Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -1192,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
     -------
     torch.Tensor:
         32-bit output tensor.
-    '''
+    """
     prev_device = pre_call(A.device)
-    if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+    if out is None:
+        out = torch.zeros_like(A, dtype=torch.float32)
     is_on_gpu([code, A, out])
     lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     post_call(prev_device)
@@ -1261,16 +1578,17 @@ def optimizer_update_32bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
-
     optim_func = None
     if g.dtype == torch.float32:
         optim_func = str2optimizer32bit[optimizer_name][0]
     elif g.dtype == torch.float16:
         optim_func = str2optimizer32bit[optimizer_name][1]
-    elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
+    elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
         optim_func = str2optimizer32bit[optimizer_name][2]
     else:
-        raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
 
     is_on_gpu([g, p, state1, state2, unorm_vec])
     prev_device = pre_call(g.device)
@@ -1290,7 +1608,8 @@ def optimizer_update_32bit(
         ct.c_float(lr),
         ct.c_float(gnorm_scale),
         ct.c_bool(skip_zeros),
-        ct.c_int32(g.numel()))
+        ct.c_int32(g.numel()),
+    )
     post_call(prev_device)
 
 
@@ -1422,7 +1741,7 @@ def optimizer_update_8bit(
         )
     else:
         raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
         )
     post_call(prev_device)
 
@@ -1446,7 +1765,6 @@ def optimizer_update_8bit_blockwise(
     gnorm_scale: float = 1.0,
     skip_zeros=False,
 ) -> None:
-
     optim_func = None
     prev_device = pre_call(g.device)
     is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
@@ -1454,12 +1772,15 @@ def optimizer_update_8bit_blockwise(
         optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
         optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
-    elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
-          len(str2optimizer8bit_blockwise[optimizer_name])==3):
+    elif (
+        g.dtype == torch.bfloat16
+        and state1.dtype == torch.uint8
+        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
+    ):
         optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
     else:
         raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
         )
     post_call(prev_device)
 
@@ -1487,9 +1808,8 @@ def optimizer_update_8bit_blockwise(
     )
     post_call(prev_device)
 
-def percentile_clipping(
-    grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
-):
+
+def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
     """Applies percentile clipping
 
     grad: torch.Tensor
@@ -1531,9 +1851,7 @@ def percentile_clipping(
     return current_gnorm, clip_value, gnorm_scale
 
 
-def histogram_scatter_add_2d(
-    histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
-):
+def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
     assert len(histogram.shape) == 2
     assert histogram.dtype == torch.float32
     assert source.dtype == torch.float32
@@ -1550,12 +1868,12 @@ def histogram_scatter_add_2d(
     is_on_gpu([histogram, index1, index2, source])
     lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
 
+
 def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
-    if not torch.cuda.is_initialized(): torch.cuda.init()
+    if not torch.cuda.is_initialized():
+        torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
-        raise TypeError(
-            f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
-        )
+        raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}")
 
     sA = A.shape
     sB = B.shape
@@ -1596,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
         sout = out.shape
         # special case common in backprop
         if not correct and len(sA) == 3 and len(sB) == 3:
-            if (
-                sout[0] == sA[2]
-                and sout[1] == sB[2]
-                and sA[0] == sB[0]
-                and sA[1] == sB[1]
-            ):
+            if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]:
                 correct = True
     else:
         if len(sA) == 2 and len(sB) == 2:
@@ -1634,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
 
     if not correct:
         raise ValueError(
-            f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
+            f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.",
         )
 
     return sout
 
+
 def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    state=None,
 ):
     prev_device = pre_call(A.device)
-    #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
+    # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
+        raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )")
 
     if A.numel() != A.shape[-1]:
-        raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
+        raise ValueError(
+            "Dimensions of A are invalid. A must be a vector with leading dimensions of 1, e.g. [1, 1, 2048].",
+        )
 
     Bshape = state.shape
     bout = Bshape[0]
@@ -1673,7 +1989,7 @@ def gemv_4bit(
     k = Bshape[1]
     lda = Bshape[0]
     ldc = Bshape[0]
-    ldb = (A.shape[-1]+1)//2
+    ldb = (A.shape[-1] + 1) // 2
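+    # ldb reflects the 4-bit packing of B: two values are stored per byte, so the leading dimension is halved and rounded up.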
     is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
@@ -1684,21 +2000,61 @@ def gemv_4bit(
 
     if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp16(
+                m,
+                n,
+                k,
+                get_ptr(A),
+                get_ptr(B),
+                get_ptr(absmax),
+                get_ptr(state.code),
+                get_ptr(out),
+                lda,
+                ldb,
+                ldc,
+                ct.c_int32(state.blocksize),
+            )
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
+            lib.cgemm_4bit_inference_naive_bf16(
+                m,
+                n,
+                k,
+                get_ptr(A),
+                get_ptr(B),
+                get_ptr(absmax),
+                get_ptr(state.code),
+                get_ptr(out),
+                lda,
+                ldb,
+                ldc,
+                ct.c_int32(state.blocksize),
+            )
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp32(
+                m,
+                n,
+                k,
+                get_ptr(A),
+                get_ptr(B),
+                get_ptr(absmax),
+                get_ptr(state.code),
+                get_ptr(out),
+                lda,
+                ldb,
+                ldc,
+                ct.c_int32(state.blocksize),
+            )
         else:
-            raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
+            raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
 
     else:
-        raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
+        raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
 
     post_call(prev_device)
 
     return out
 
+
 def igemm(
     A: Tensor,
     B: Tensor,
@@ -1764,7 +2120,7 @@ def igemm(
         assert len(sA) == 3
         if not (sA[0] == sB[0] and sA[1] == sB[1]):
             raise ValueError(
-                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
+                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
             )
 
         transposed_A = True
@@ -1783,8 +2139,20 @@ def igemm(
     # B^T @ A^T = C^T
     # [km, nk -> mn]
     is_on_gpu([B, A, out])
-    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
-               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
+    lib.cigemm(
+        ptr,
+        ct.c_bool(transposed_B),
+        ct.c_bool(transposed_A),
+        ct.c_int32(m),
+        ct.c_int32(n),
+        ct.c_int32(k),
+        get_ptr(B),
+        get_ptr(A),
+        get_ptr(out),
+        ct.c_int32(lda),
+        ct.c_int32(ldb),
+        ct.c_int32(ldc),
+    )
     return out
 
 
@@ -1796,9 +2164,7 @@ def batched_igemm(
     transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
-        raise ValueError(
-            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
-        )
+        raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}")
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
         out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
@@ -1865,9 +2231,24 @@ def batched_igemm(
     ptr = CUBLAS_Context.get_instance().get_context(A.device)
 
     is_on_gpu([B, A, out])
-    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
-               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
-               ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
+    lib.cbatched_igemm(
+        ptr,
+        ct.c_bool(transposed_B),
+        ct.c_bool(transposed_A),
+        ct.c_int32(m),
+        ct.c_int32(n),
+        ct.c_int32(k),
+        get_ptr(B),
+        get_ptr(A),
+        get_ptr(out),
+        ct.c_int32(lda),
+        ct.c_int32(ldb),
+        ct.c_int32(ldc),
+        ct.c_long(strideA),
+        ct.c_long(strideB),
+        ct.c_long(strideC),
+        ct.c_uint32(num_batch),
+    )
     return out
 
 
@@ -1876,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeB = SB[0]
     dimsA = len(shapeA)
     dimsB = len(shapeB)
-    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
+    assert dimsB == 2, "Only two dimensional matrices are supported for argument B"
     if dimsA == 2:
         m = shapeA[0]
     elif dimsA == 3:
         m = shapeA[0] * shapeA[1]
 
     rows = n = shapeB[0]
-    assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+    assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}"
 
     # if the tensor is empty, return a transformed empty tensor with the right dimensions
     if shapeA[0] == 0 and dimsA == 2:
@@ -1892,13 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
         return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
 
     if dimsA == 2 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
     elif dimsA == 3 and out is None:
-        out, Sout = get_transform_buffer(
-            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
-        )
+        out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
 
     assert dimsB != 3, "len(B.shape)==3 not supported"
     assert A.device.type == "cuda"
@@ -1940,49 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     has_error = 0
     ptrRowScale = get_ptr(None)
     is_on_gpu([A, B, out])
-    if formatB == 'col_turing':
+    if formatB == "col_turing":
         if dtype == torch.int32:
-            has_error = lib.cigemmlt_turing_32(
-                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
-            )
+            has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_turing_8(
-                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
-            )
+            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
     elif formatB == "col_ampere":
         if dtype == torch.int32:
-            has_error = lib.cigemmlt_ampere_32(
-                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
-            )
+            has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_ampere_8(
-                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
-            )
+            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
 
     if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
         raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
 
     if has_error:
-        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
-        raise Exception('cublasLt ran into an error!')
+        print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}")
+        raise Exception("cublasLt ran into an error!")
 
     torch.cuda.set_device(prev_device)
 
     return out, Sout
 
 
-def mm_dequant(
-    A,
-    quant_state,
-    row_stats,
-    col_stats,
-    out=None,
-    new_row_stats=None,
-    new_col_stats=None,
-    bias=None
-):
+def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
     assert A.dtype == torch.int32
-    if bias is not None: assert bias.dtype == torch.float16
+    if bias is not None:
+        assert bias.dtype == torch.float16
     out_shape = quant_state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
@@ -1990,19 +2351,11 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(
-            out_shape[0], dtype=torch.float32, device=A.device
-        )
+        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
     if new_col_stats is None:
-        new_col_stats = torch.empty(
-            out_shape[1], dtype=torch.float32, device=A.device
-        )
-    assert (
-        new_row_stats.shape[0] == row_stats.shape[0]
-    ), f"{new_row_stats.shape} vs {row_stats.shape}"
-    assert (
-        new_col_stats.shape[0] == col_stats.shape[0]
-    ), f"{new_col_stats.shape} vs {col_stats.shape}"
+        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
+    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
 
     prev_device = pre_call(A.device)
     ptrA = get_ptr(A)
@@ -2016,15 +2369,23 @@ def mm_dequant(
     numCols = ct.c_int32(out_shape[1])
 
     is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
-    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
+    lib.cdequant_mm_int32_fp16(
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrOut,
+        ptrNewRowStats,
+        ptrNewColStats,
+        ptrBias,
+        numRows,
+        numCols,
+    )
     post_call(prev_device)
 
     return out
 
 
-def get_colrow_absmax(
-    A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
-):
+def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
     assert A.dtype == torch.float16
     device = A.device
 
@@ -2037,18 +2398,12 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty(
-            (rows,), dtype=torch.float32, device=device
-        ).fill_(-50000.0)
+        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty(
-            (cols,), dtype=torch.float32, device=device
-        ).fill_(-50000.0)
+        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
 
     if nnz_block_ptr is None and threshold > 0.0:
-        nnz_block_ptr = torch.zeros(
-            ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
-        )
+        nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)
 
     ptrA = get_ptr(A)
     ptrRowStats = get_ptr(row_stats)
@@ -2122,14 +2477,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros(
-        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
-    )
+    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
-    return CSRSparseTensor(
-        cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
-    )
+    return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
 
 
 def coo2csc(cooA):
@@ -2138,14 +2489,10 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros(
-        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
-    )
+    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(
-        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
-    )
+    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
 
 
 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -2155,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
 
 
-def double_quant(
-    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
-):
+def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     device = A.device
     assert A.dtype == torch.half
     assert device.type == "cuda"
@@ -2170,9 +2515,7 @@ def double_quant(
         rows = A.shape[0]
 
     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
-            A, threshold=threshold
-        )
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
 
     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -2190,9 +2533,7 @@ def double_quant(
     if threshold > 0.0:
         nnz = nnz_row_ptr[-1].item()
         if nnz > 0:
-            coo_tensor = coo_zeros(
-                A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
-            )
+            coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
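+            # Outlier entries whose magnitude crosses the threshold are additionally collected into this sparse COO tensor for separate handling.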
             ptrRowIdx = get_ptr(coo_tensor.rowidx)
             ptrColIdx = get_ptr(coo_tensor.colidx)
             ptrVal = get_ptr(coo_tensor.values)
@@ -2251,12 +2592,16 @@ def double_quant(
     return out_row, out_col, row_stats, col_stats, coo_tensor
 
 
-def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
     prev_device = pre_call(A.device)
-    if state is None: state = (A.shape, from_order)
-    else: from_order = state[1]
-    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
-    else: new_state = (state[0], to_order) # (shape, order)
+    if state is None:
+        state = (A.shape, from_order)
+    else:
+        from_order = state[1]
+    if out is None:
+        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+    else:
+        new_state = (state[0], to_order)  # (shape, order)
 
     shape = state[0]
     if len(shape) == 2:
@@ -2267,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
         dim2 = ct.c_int32(shape[2])
 
     is_on_gpu([A, out])
-    if to_order == 'col32':
+    if to_order == "col32":
         if transpose:
             lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
         else:
@@ -2288,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
         elif from_order == "col_ampere":
             lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
     else:
-        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
+        raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
 
     post_call(prev_device)
 
@@ -2297,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
 
 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty(
-            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
-        )
+        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
@@ -2326,16 +2669,28 @@ def spmm_coo(cooA, B, out=None):
     cldc = ct.c_int32(ldc)
 
     is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
-    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
+    lib.cspmm_coo(
+        ptr,
+        ptrRowidx,
+        ptrColidx,
+        ptrValues,
+        cnnz,
+        crowsA,
+        ccolsA,
+        ccolsB,
+        cldb,
+        ptrB,
+        cldc,
+        ptrC,
+        ct.c_bool(transposed_B),
+    )
 
     return out
 
 
 def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     if out is None:
-        out = torch.zeros(
-            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
-        )
+        out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
     nnz = cooA.nnz
     prev_device = pre_call(B.device)
     assert cooA.rowidx.numel() == nnz
@@ -2353,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     max_count, max_idx = torch.sort(counts, descending=True)
     max_idx = max_idx.int()
     max_count = max_count.int()
-    assert (
-        max_count[0] <= 32
-    ), f"Current max count per row is 8 but found {max_count[0]}."
+    assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}."
     assert B.dtype in [torch.float16, torch.int8]
     ptrOffset = get_ptr(offset)
     ptrMaxCount = get_ptr(max_count)
@@ -2443,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
     elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
         dtype = x.dtype
         x = x.float()
-        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
-            x, dim=dim, keepdim=True
-        )
+        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
         dyna[dyna == 0] = 1
         qx = 255.0 / dyna
         minx = torch.amin(x, dim=dim, keepdim=True)
@@ -2553,9 +2904,7 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"
 
-    out = torch.zeros(
-        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
-    )
+    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
 
     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])
@@ -2565,7 +2914,7 @@ def extract_outliers(A, SA, idx):
     ptrOut = get_ptr(out)
 
     prev_device = pre_call(A.device)
-    if formatA == 'col_turing':
+    if formatA == "col_turing":
         lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
     elif formatA == "col_ampere":
         lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
@@ -2573,6 +2922,7 @@ def extract_outliers(A, SA, idx):
 
     return out
 
+
 def pipeline_test(A, batch_size):
     out = torch.zeros_like(A)
     lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index f7b96205b..e1cc6600d 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding):
         reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
         forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
     """
+
     def __init__(
         self,
         num_embeddings: int,
@@ -89,9 +90,7 @@ def __init__(
             dtype,
         )
         self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
-        GlobalOptimManager.get_instance().register_module_override(
-            self, "weight", {"optim_bits": 32}
-        )
+        GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
 
     def reset_parameters(self) -> None:
         torch.nn.init.xavier_uniform_(self.weight)
@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding):
     """
     Embedding class to store and retrieve word embeddings from their indices.
     """
+
     def __init__(
         self,
         num_embeddings: int,
@@ -170,11 +170,9 @@ def __init__(
             scale_grad_by_freq,
             sparse,
             _weight,
-            device=device
-        )
-        GlobalOptimManager.get_instance().register_module_override(
-            self, "weight", {"optim_bits": 32}
+            device=device,
         )
+        GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
 
     def reset_parameters(self) -> None:
         torch.nn.init.xavier_uniform_(self.weight)
@@ -208,16 +206,16 @@ def forward(self, input: Tensor) -> Tensor:
 
 class Params4bit(torch.nn.Parameter):
     def __new__(
-            cls,
-            data: Optional[torch.Tensor] = None,
-            requires_grad=False,  # quantized weights should be frozen by default
-            quant_state: Optional[QuantState] = None,
-            blocksize: int = 64,
-            compress_statistics: bool = True,
-            quant_type: str = 'fp4',
-            quant_storage: torch.dtype = torch.uint8,
-            module: Optional["Linear4bit"] = None,
-            bnb_quantized: bool = False
+        cls,
+        data: Optional[torch.Tensor] = None,
+        requires_grad=False,  # quantized weights should be frozen by default
+        quant_state: Optional[QuantState] = None,
+        blocksize: int = 64,
+        compress_statistics: bool = True,
+        quant_type: str = "fp4",
+        quant_storage: torch.dtype = torch.uint8,
+        module: Optional["Linear4bit"] = None,
+        bnb_quantized: bool = False,
     ) -> "Params4bit":
         if data is None:
             data = torch.empty(0)
@@ -250,7 +248,7 @@ def __setstate__(self, state):
         self.bnb_quantized = state["bnb_quantized"]
         self.module = state["module"]
 
-    def __deepcopy__(self,memo):
+    def __deepcopy__(self, memo):
         new_instance = type(self).__new__(type(self))
         state = self.__getstate__()
         new_instance.__setstate__(state)
@@ -265,7 +263,14 @@ def __copy__(self):
         return new_instance
 
     @classmethod
-    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
+    def from_prequantized(
+        cls,
+        data: torch.Tensor,
+        quantized_stats: Dict[str, Any],
+        requires_grad: bool = False,
+        device="cuda",
+        **kwargs,
+    ) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
         self.requires_grad = requires_grad
         self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
@@ -292,33 +297,39 @@ def _quantize(self, device):
         return self
 
     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
-        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
 
     @overload
-    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
-        ...
+    def to(
+        self: T,
+        device: Optional[Union[int, device]] = ...,
+        dtype: Optional[Union[dtype, str]] = ...,
+        non_blocking: bool = ...,
+    ) -> T: ...
 
     @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
 
     @overload
-    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+        if device is not None and device.type == "cuda" and not self.bnb_quantized:
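+            # Moving a not-yet-quantized Params4bit onto a CUDA device performs the actual 4-bit quantization.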
             return self._quantize(device)
         else:
             if self.quant_state is not None:
                 self.quant_state.to(device)
 
-            new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
-                                   requires_grad=self.requires_grad, quant_state=self.quant_state,
-                                   blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                   quant_type=self.quant_type)
+            new_param = Params4bit(
+                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                requires_grad=self.requires_grad,
+                quant_state=self.quant_state,
+                blocksize=self.blocksize,
+                compress_statistics=self.compress_statistics,
+                quant_type=self.quant_type,
+            )
 
             return new_param
 
@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear):
     quantized_model = quantized_model.to(0) # Quantization happens here
     ```
     """
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
+
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+        device=None,
+    ):
         """
         Initialize Linear4bit class.
 
@@ -368,7 +390,14 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
                 Whether the linear class uses the bias term as well.
         """
         super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
+        self.weight = Params4bit(
+            self.weight.data,
+            requires_grad=False,
+            compress_statistics=compress_statistics,
+            quant_type=quant_type,
+            quant_storage=quant_storage,
+            module=self,
+        )
         # self.persistent_buffers = []  # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
@@ -385,11 +414,15 @@ def set_compute_type(self, x):
             if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                 # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                 # warn the user about this
-                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
-                warnings.filterwarnings('ignore', message='.*inference.')
+                warnings.warn(
+                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
+                )
+                warnings.filterwarnings("ignore", message=".*inference.")
             if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
-                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
-                warnings.filterwarnings('ignore', message='.*inference or training')
+                warnings.warn(
+                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
+                )
+                warnings.filterwarnings("ignore", message=".*inference or training")
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
@@ -407,8 +440,8 @@ def forward(self, x: torch.Tensor):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
-        if getattr(self.weight, 'quant_state', None) is None:
-            if getattr(self, 'quant_state', None) is not None:
+        if getattr(self.weight, "quant_state", None) is None:
+            if getattr(self, "quant_state", None) is not None:
                 # the quant state got lost when the parameter got converted. This happens for example for fsdp
                 # since we registered the module, we can recover the state here
                 assert self.weight.shape[1] == 1
@@ -416,7 +449,9 @@ def forward(self, x: torch.Tensor):
                     self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
                 self.weight.quant_state = self.quant_state
             else:
-                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+                print(
+                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
+                )
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit):
     """
     Implements the FP4 data type.
     """
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
         """
         Args:
             input_features (`int`):
@@ -447,21 +492,40 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.
         """
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
+        super().__init__(
+            input_features,
+            output_features,
+            bias,
+            compute_dtype,
+            compress_statistics,
+            "fp4",
+            quant_storage,
+            device,
+        )
 
 
 class LinearNF4(Linear4bit):
-    ''' Implements the NF4 data type.
+    """Implements the NF4 data type.
+
+    Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
+    is normalized into the range [-1, 1].
 
-        Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
-        is normalized into the range [-1, 1].
+    For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
 
-        For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
+    Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
+    the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
+    """
 
-        Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
-        the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
-    '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
         """
         Args:
             input_features (`int`):
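A short usage sketch for the NF4 layer documented above (sizes are arbitrary, and the export path `bnb.nn.LinearNF4` is assumed). `compress_statistics=True` additionally quantizes the per-block absmax statistics, the double quantization described in the QLoRA paper.

import torch
import bitsandbytes as bnb

nf4_layer = bnb.nn.LinearNF4(1024, 1024, compute_dtype=torch.bfloat16, compress_statistics=True)
nf4_layer = nf4_layer.cuda()  # weights are quantized to NF4 here
out = nf4_layer(torch.randn(2, 1024, dtype=torch.bfloat16, device="cuda"))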
@@ -471,7 +535,16 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.
         """
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
+        super().__init__(
+            input_features,
+            output_features,
+            bias,
+            compute_dtype,
+            compress_statistics,
+            "nf4",
+            quant_storage,
+            device,
+        )
 
 
 class Int8Params(torch.nn.Parameter):
@@ -514,33 +587,22 @@ def to(
         device: Optional[Union[int, device]] = ...,
         dtype: Optional[Union[dtype, str]] = ...,
         non_blocking: bool = ...,
-    ) -> T:
-        ...
+    ) -> T: ...
 
     @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
 
     @overload
-    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
 
     def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-            *args, **kwargs
-        )
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if (
-            device is not None
-            and device.type == "cuda"
-            and self.data.device.type == "cpu"
-        ):
+        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(
-                    device=device, dtype=dtype, non_blocking=non_blocking
-                ),
+                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear):
     int8_model = int8_model.to(0) # Quantization happens here
     ```
     """
-    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                       memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+        device=None,
+    ):
         """
         Initialize Linear8bitLt class.
 
@@ -647,19 +719,36 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
                 destination[format_name] = self.state.formatB
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
         unexpected_copy = list(unexpected_keys)
 
         for key in unexpected_copy:
-            input_name = key[len(prefix):]
+            input_name = key[len(prefix) :]
             if input_name == "SCB":
                 if self.weight.SCB is None:
                     # buffers not yet initialized, can't access them directly without quantizing first
-                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
-                                       "not supported. Please call module.cuda() before module.load_state_dict()")
+                    raise RuntimeError(
+                        "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
+                        "not supported. Please call module.cuda() before module.load_state_dict()",
+                    )
 
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
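The error message above implies a fixed ordering when restoring a checkpoint that was saved from an already-quantized module (one that contains `SCB` scale buffers): quantize first, then load. A hedged sketch with placeholder sizes and filename:

import torch
import bitsandbytes as bnb

module = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
module = module.cuda()  # quantization happens here, so weight.SCB exists
state_dict = torch.load("int8_checkpoint.pt")  # placeholder path
module.load_state_dict(state_dict)  # the SCB entries are copied in the loop above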
@@ -702,18 +791,18 @@ def __init__(self, input_features, output_features, bias=True, device=None):
         self.is_quantized = False
 
     def forward_with_outliers(self, x, outlier_idx):
-        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
+        raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function")
 
     def quantize_weight(self, w, outlier_idx):
-        raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')
+        raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function")
 
     def forward(self, x):
         if self.outlier_dim is None:
             tracer = OutlierTracer.get_instance()
             if not tracer.is_initialized():
-                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
+                print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
             outlier_idx = tracer.get_outliers(self.weight)
-            #print(outlier_idx, tracer.get_hvalue(self.weight))
+            # print(outlier_idx, tracer.get_hvalue(self.weight))
             self.outlier_dim = outlier_idx
 
         if not self.is_quantized:
@@ -721,6 +810,7 @@ def forward(self, x):
             self.weight.data.copy_(w)
             self.is_quantized = True
 
+
 class SwitchBackLinearBnb(nn.Linear):
     def __init__(
         self,
@@ -731,11 +821,9 @@ def __init__(
         memory_efficient_backward=False,
         threshold=0.0,
         index=None,
-        device=None
+        device=None,
     ):
-        super().__init__(
-            input_features, output_features, bias, device
-        )
+        super().__init__(input_features, output_features, bias, device)
         self.state = bnb.MatmulLtState()
         self.index = index
 
@@ -745,9 +833,7 @@ def __init__(
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py
index 9c7738c59..aa8494942 100644
--- a/bitsandbytes/nn/triton_based_modules.py
+++ b/bitsandbytes/nn/triton_based_modules.py
@@ -22,7 +22,6 @@
 
 
 class _switchback_global(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, X_3D, W, bias):
         # reshape input to [N * L, D]
@@ -37,9 +36,7 @@ def forward(ctx, X_3D, W, bias):
 
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequantize(
-            X_int8, W_int8.t(), state_X, state_W, bias
-        ).view(*X_3D.size()[:-1], -1)
+        return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
 
     @staticmethod
     def backward(ctx, G_3D):
@@ -56,7 +53,8 @@ def backward(ctx, G_3D):
             G_int8, state_G = quantize_rowwise(G)
             W_int8, state_W = quantize_global_transpose(W)
             grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
-                *G_3D.size()[:-1], -1
+                *G_3D.size()[:-1],
+                -1,
             )
         if ctx.needs_input_grad[1]:
             # backward pass uses standard weight grad
@@ -66,8 +64,8 @@ def backward(ctx, G_3D):
 
         return grad_X, grad_W, grad_bias
 
-class _switchback_vectorrize(torch.autograd.Function):
 
+class _switchback_vectorrize(torch.autograd.Function):
     @staticmethod
     def forward(ctx, X_3D, W, bias):
         # reshape input to [N * L, D]
@@ -81,9 +79,7 @@ def forward(ctx, X_3D, W, bias):
 
         # matmult, fused dequant and add bias
         # call kernel which expects rowwise quantized X and W
-        return int8_matmul_rowwise_dequantize(
-            X_int8, W_int8.t(), state_X, state_W, bias
-        ).view(*X_3D.size()[:-1], -1)
+        return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
 
     @staticmethod
     def backward(ctx, G_3D):
@@ -99,7 +95,8 @@ def backward(ctx, G_3D):
             G_int8, state_G = quantize_rowwise(G)
             W_int8, state_W = quantize_columnwise_and_transpose(W)
             grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
-                *G_3D.size()[:-1], -1
+                *G_3D.size()[:-1],
+                -1,
             )
         if ctx.needs_input_grad[1]:
             # backward pass uses standard weight grad
@@ -109,8 +106,8 @@ def backward(ctx, G_3D):
 
         return grad_X, grad_W, grad_bias
 
-class _switchback_global_mem_efficient(torch.autograd.Function):
 
+class _switchback_global_mem_efficient(torch.autograd.Function):
     @staticmethod
     def forward(ctx, X_3D, W, bias):
         # reshape input to [N * L, D]
@@ -127,9 +124,7 @@ def forward(ctx, X_3D, W, bias):
 
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequantize(
-            X_int8, W_int8.t(), state_X, state_W, bias
-        ).view(*X_3D_sz[:-1], -1)
+        return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
 
     @staticmethod
     def backward(ctx, G_3D):
@@ -151,35 +146,34 @@ def backward(ctx, G_3D):
             G_int8, state_G = quantize_rowwise(G)
             del G
             W_int8 = W_int8.t().contiguous()
-            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
-                *G_3D_sz[:-1], -1
-            )
+            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
 
         return grad_X, grad_W, grad_bias
 
+
 class SwitchBackLinear(nn.Linear):
     def __init__(
-            self,
-            in_features: int,
-            out_features: int,
-            bias: bool = True,
-            device=None,
-            dtype=None,
-            vector_wise_quantization: bool = False,
-            mem_efficient : bool = False,
-        ):
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+        vector_wise_quantization: bool = False,
+        mem_efficient: bool = False,
+    ):
         super().__init__(in_features, out_features, bias, device, dtype)
 
         if not is_triton_available():
-            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
-                               Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
+            raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear.
+                               Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""")
 
         # By default, we use the global quantization.
         self.vector_wise_quantization = vector_wise_quantization
         if self.vector_wise_quantization:
             self._fn = _switchback_vectorrize
             if mem_efficient:
-                print('mem efficient is not supported for vector-wise quantization.')
+                print("mem efficient is not supported for vector-wise quantization.")
                 exit(1)
         else:
             if mem_efficient:
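An illustrative construction of the layer configured in this `__init__` (a CUDA device and triton are required, and the export path `bnb.nn.SwitchBackLinear` is assumed); the default global quantization is kept and the sizes are made up.

import torch
import bitsandbytes as bnb

layer = bnb.nn.SwitchBackLinear(1024, 3072, vector_wise_quantization=False).cuda().half()
x = torch.randn(4, 128, 1024, dtype=torch.float16, device="cuda")  # [batch, seq, dim]
y = layer(x)  # [4, 128, 3072]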
@@ -195,7 +189,7 @@ def prepare_for_eval(self):
         #     if hasattr(m, "prepare_for_eval"):
         #         m.prepare_for_eval()
         # model.apply(cond_prepare)
-        print('=> preparing for eval.')
+        print("=> preparing for eval.")
         if self.vector_wise_quantization:
             W_int8, state_W = quantize_rowwise(self.weight)
         else:
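The commented-out recipe above can be written as a small standalone helper; this is only a sketch of that pattern, not code from the patch.

def prepare_switchback_for_eval(model):
    # Pre-quantize the weights of every submodule that defines prepare_for_eval(),
    # so inference skips the per-forward weight quantization.
    def cond_prepare(m):
        if hasattr(m, "prepare_for_eval"):
            m.prepare_for_eval()

    model.apply(cond_prepare)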
@@ -219,18 +213,22 @@ def forward(self, x):
             X_int8, state_X = quantize_rowwise(X)
 
             if self.vector_wise_quantization:
-                return int8_matmul_rowwise_dequantize(
-                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
-                ).view(*x.size()[:-1], -1)
+                return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
+                    *x.size()[:-1],
+                    -1,
+                )
             else:
-                return int8_matmul_mixed_dequantize(
-                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
-                ).view(*x.size()[:-1], -1)
+                return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
+                    *x.size()[:-1],
+                    -1,
+                )
+
 
 SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
 SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
 SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
 
+
 # This is just the standard linear function.
 class StandardLinearFunction(torch.autograd.Function):
     @staticmethod
@@ -260,7 +258,7 @@ def backward(ctx, grad_output_3D):
 
         return grad_input, grad_weight, grad_bias
 
-class StandardLinear(nn.Linear):
 
+class StandardLinear(nn.Linear):
     def forward(self, x):
         return StandardLinearFunction.apply(x, self.weight, self.bias)
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index c2ea87ab0..aace548fa 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -50,9 +50,7 @@ def __init__(
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
@@ -119,9 +117,7 @@ def __init__(
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
@@ -189,9 +185,7 @@ def __init__(
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index e534c8b8f..d8ffca63e 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -14,8 +14,21 @@
 
 
 class Adam(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         Base Adam optimizer.
 
@@ -45,11 +58,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class Adam8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         8-bit Adam optimizer.
 
@@ -79,11 +119,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class Adam32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         32-bit Adam optimizer.
 
@@ -113,11 +180,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class PagedAdam(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         Paged Adam optimizer.
 
@@ -147,11 +241,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedAdam8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         8-bit paged Adam optimizer.
 
@@ -181,11 +302,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedAdam32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         Paged 32-bit Adam optimizer.
 
@@ -215,7 +363,21 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.
@@ -293,9 +455,7 @@ def step(self, closure=None):
                 if grad.dtype in {torch.float16, torch.bfloat16}:
                     grad = grad.float()
                 if grad.is_sparse:
-                    raise RuntimeError(
-                        "Adam does not support sparse gradients, please consider SparseAdam instead"
-                    )
+                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
                 amsgrad = group.get("amsgrad", False)
                 assert not amsgrad
 
@@ -312,15 +472,9 @@ def step(self, closure=None):
                     state["exp_avg"] = torch.zeros_like(p_data_fp32)
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
-                    state["abserrors"] = torch.zeros(
-                        (256, 256), device=p_data_fp32.device
-                    )
-                    state["relerrors"] = torch.zeros(
-                        (256, 256), device=p_data_fp32.device
-                    )
-                    state["counts"] = torch.zeros(
-                        (256, 256), device=p_data_fp32.device
-                    )
+                    state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                    state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
                         state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -328,25 +482,19 @@ def step(self, closure=None):
                     state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                     state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                     if amsgrad:
-                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
-                            p_data_fp32
-                        )
+                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
 
                 state["step"] += 1
                 beta1, beta2 = group["betas"]
                 bias_correction1 = 1 - beta1 ** state["step"]
                 bias_correction2 = 1 - beta2 ** state["step"]
-                step_size = (
-                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
-                )
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                 e = state["abserrors"]
                 rele = state["relerrors"]
                 counts = state["counts"]
 
                 if group["weight_decay"] != 0:
-                    p_data_fp32.add_(
-                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
-                    )
+                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 if amsgrad:
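For reference, the bias-corrected step size computed in this hunk is the standard Adam update; a scalar restatement with made-up values:

import math
import torch

beta1, beta2, lr, eps, step = 0.9, 0.999, 1e-3, 1e-8, 10  # illustrative hyperparameters
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr * math.sqrt(bias_correction2) / bias_correction1

exp_avg = torch.tensor([0.05])      # running average of gradients
exp_avg_sq = torch.tensor([0.004])  # running average of squared gradients
update = exp_avg / (exp_avg_sq.sqrt() + eps)
param_delta = -step_size * update   # mirrors p_data_fp32 += -step_size * update_fp32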
@@ -359,10 +507,7 @@ def step(self, closure=None):
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
                 update_fp32 = exp_avg / denom
 
-                if (
-                    p_data_fp32.numel() <= 8192
-                    or p_data_fp32.numel() > 50000 * 1000
-                ):
+                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                     # embedding layer or too small
                     p_data_fp32 += -step_size * update_fp32
                 else:
@@ -401,9 +546,7 @@ def step(self, closure=None):
                         # 3. dequantize
                         # Error will be calculated automatically!
                     else:
-                        raise ValueError(
-                            f"Invalid analysis value: {self.analysis}!"
-                        )
+                        raise ValueError(f"Invalid analysis value: {self.analysis}!")
 
                     denom = state2.sqrt().add_(group["eps"])
                     update_8bit = state1 / denom
@@ -415,9 +558,7 @@ def step(self, closure=None):
 
                     F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
                     F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
-                    F.histogram_scatter_add_2d(
-                        counts, C1.int(), C2.int(), torch.ones_like(abserr)
-                    )
+                    F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
 
                     p_data_fp32 += -step_size * update_fp32
 
@@ -425,18 +566,10 @@ def step(self, closure=None):
                         if self.savedir != "" and state["step"] % 100 == 0:
                             if not os.path.exists(self.savedir):
                                 os.makedirs(self.savedir)
-                            shapestr = "_".join(
-                                [str(dim) for dim in p_data_fp32.shape]
-                            )
-                            pathe = os.path.join(
-                                self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
-                            )
-                            pathrele = os.path.join(
-                                self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
-                            )
-                            pathcounts = os.path.join(
-                                self.savedir, f"{p_id}_{shapestr}_counts.pkl"
-                            )
+                            shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                            pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
+                            pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
+                            pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
                             torch.save(e, pathe)
                             torch.save(rele, pathrele)
                             torch.save(counts, pathcounts)
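Since every Adam variant above simply forwards its arguments to Optimizer2State, switching an existing training loop to the 8-bit or paged versions is a one-line change; a hypothetical sketch:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
# bnb.optim.PagedAdam8bit(...) additionally allows optimizer state to be paged out of GPU memory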
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index 1e2dc04de..fa51458fd 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -6,8 +6,21 @@
 
 
 class AdamW(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         Base AdamW optimizer.
 
@@ -37,11 +50,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class AdamW8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         8-bit AdamW optimizer.
 
@@ -71,11 +111,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class AdamW32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         32-bit AdamW optimizer.
 
@@ -105,12 +172,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
 
 
 class PagedAdamW(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged AdamW optimizer.
 
@@ -140,11 +232,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedAdamW8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged 8-bit AdamW optimizer.
 
@@ -174,11 +292,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedAdamW32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-                       args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged 32-bit AdamW optimizer.
 
@@ -208,4 +352,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
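The AdamW classes follow the same pattern, differing only in the `weight_decay` default and the fixed bit-width/paging arguments they pass to Optimizer2State; an illustrative setup for the paged 8-bit variant:

import torch
import bitsandbytes as bnb

params = torch.nn.Linear(2048, 2048).cuda().parameters()
# The Paged* classes force is_paged=True internally, so no extra flag is needed.
optimizer = bnb.optim.PagedAdamW8bit(params, lr=1e-3, weight_decay=1e-2)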
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 7449b805b..63c062988 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -51,9 +51,7 @@ def __init__(
                 The maximum gradient norm.
         """
         if momentum == 0:
-            raise NotImplementedError(
-                "LARS without momentum is not supported!"
-            )
+            raise NotImplementedError("LARS without momentum is not supported!")
         super().__init__(
             "lars",
             params,
@@ -110,9 +108,7 @@ def __init__(
                 The maximum gradient norm.
         """
         if momentum == 0:
-            raise NotImplementedError(
-                "LARS without momentum is not supported!"
-            )
+            raise NotImplementedError("LARS without momentum is not supported!")
         super().__init__(
             "lars",
             params,
@@ -169,9 +165,7 @@ def __init__(
                 The maximum gradient norm.
         """
         if momentum == 0:
-            raise NotImplementedError(
-                "LARS without momentum is not supported!"
-            )
+            raise NotImplementedError("LARS without momentum is not supported!")
         super().__init__(
             "lars",
             params,
@@ -204,9 +198,7 @@ def __init__(
         if momentum < 0.0:
             raise ValueError(f"Invalid momentum value: {momentum}")
         if weight_decay < 0.0:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
         defaults = dict(
             lr=lr,
@@ -217,9 +209,7 @@ def __init__(
             max_unorm=max_unorm,
         )
         if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError(
-                "Nesterov momentum requires a momentum and zero dampening"
-            )
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
 
     def __setstate__(self, state):
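The checks above mean these LARS classes always require a nonzero momentum; a minimal, illustrative construction assuming the usual `bnb.optim.LARS8bit` export:

import torch
import bitsandbytes as bnb

params = torch.nn.Linear(512, 512).cuda().parameters()
opt = bnb.optim.LARS8bit(params, lr=0.1, momentum=0.9)  # momentum=0 would raise NotImplementedError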
diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py
index ce185f863..9f0f4a8a9 100644
--- a/bitsandbytes/optim/lion.py
+++ b/bitsandbytes/optim/lion.py
@@ -6,7 +6,19 @@
 
 
 class Lion(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         Base Lion optimizer.
 
@@ -32,10 +44,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bit
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class Lion8bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         8-bit Lion optimizer.
 
@@ -59,10 +96,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
 
 class Lion32bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
         """
         32-bit Lion optimizer.
 
@@ -86,11 +148,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
 
 
 class PagedLion(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged Lion optimizer.
 
@@ -114,10 +200,34 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bit
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedLion8bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged 8-bit Lion optimizer.
 
@@ -141,10 +251,34 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
+
 
 class PagedLion32bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         """
         Paged 32-bit Lion optimizer.
 
@@ -168,4 +302,17 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         """
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.0,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=True,
+        )
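Lion has no epsilon term, which is why each constructor above passes `0.0` in the eps slot of Optimizer1State; an illustrative 8-bit setup:

import torch
import bitsandbytes as bnb

params = torch.nn.Linear(1024, 1024).cuda().parameters()
opt = bnb.optim.Lion8bit(params, lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
# bnb.optim.PagedLion8bit takes the same arguments and sets is_paged=True internally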
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index a97afb026..43ebbb24d 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -21,6 +21,7 @@ class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
     """
+
     _instance = None
 
     def __init__(self):
@@ -48,13 +49,9 @@ def register_parameters(self, params):
         for group_index, group in enumerate(param_groups):
             for p_index, p in enumerate(group["params"]):
                 if id(p) in self.pid2config:
-                    self.index2config[(group_index, p_index)] = self.pid2config[
-                        id(p)
-                    ]
+                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
 
-    def override_config(
-        self, parameters, key=None, value=None, key_value_dict=None
-    ):
+    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
         """
         Override initial optimizer config with specific hyperparameters.
 
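The override mechanism documented above is normally driven through the singleton; a sketch (module layout and the overridden key are hypothetical) that keeps one weight's optimizer state in 32 bits while everything else uses 8 bits:

import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())  # register before training starts
model = model.cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
mng.override_config(model[1].weight, "optim_bits", 32)  # this parameter keeps 32-bit state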
@@ -132,18 +129,18 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False):
 
         self.mng = GlobalOptimManager.get_instance()
         self.non_castable_tensor_keys = {
-                "qmap1",
-                "qmap2",
-                "max1",
-                "max2",
-                "new_max1",
-                "new_max2",
-                "state1",
-                "state2",
-                "gnorm_vec",
-                "absmax1",
-                "absmax2",
-                "unorm_vec",
+            "qmap1",
+            "qmap2",
+            "max1",
+            "max2",
+            "new_max1",
+            "new_max2",
+            "state1",
+            "state2",
+            "gnorm_vec",
+            "absmax1",
+            "absmax2",
+            "unorm_vec",
         }
 
         if optim_bits == 8:
@@ -170,16 +167,12 @@ def load_state_dict(self, state_dict):
         saved_groups = state_dict["param_groups"]
 
         if len(groups) != len(saved_groups):
-            raise ValueError(
-                "loaded state dict has a different number of "
-                "parameter groups"
-            )
+            raise ValueError("loaded state dict has a different number of parameter groups")
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
         if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
             raise ValueError(
-                "loaded state dict contains a parameter group "
-                "that doesn't match the size of optimizer's group"
+                "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
             )
 
         # Update the state
@@ -228,9 +221,7 @@ def update_group(group, new_group):
             new_group["params"] = group["params"]
             return new_group
 
-        param_groups = [
-            update_group(g, ng) for g, ng in zip(groups, saved_groups)
-        ]
+        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
         self.__setstate__({"state": state, "param_groups": param_groups})
 
     def to_gpu(self):
@@ -240,7 +231,7 @@ def to_gpu(self):
                     values = self.state[p]
                     for k, v in values.items():
                         if isinstance(v, torch.Tensor):
-                            is_paged = getattr(v, 'is_paged', False)
+                            is_paged = getattr(v, "is_paged", False)
                             if not is_paged:
                                 self.state[p][k] = v.to(p.device)
 
@@ -248,9 +239,7 @@ def check_overrides(self):
         for module, attr, config in self.mng.module_weight_config_triple:
             pmodule = getattr(module, attr)
             assert pmodule is not None
-            assert isinstance(pmodule, torch.Tensor) or isinstance(
-                pmodule, torch.Parameter
-            )
+            assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.nn.Parameter)
             found = False
             for gindex, group in enumerate(self.param_groups):
                 if found:
@@ -262,9 +251,7 @@ def check_overrides(self):
                         # found the matching parameter
                         # init override
                         self.mng.pid2config[id(p)] = config
-                        self.mng.index2config[
-                            (gindex, pindex)
-                        ] = self.mng.pid2config[id(p)]
+                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
                         found = True
 
     @torch.no_grad()
@@ -287,7 +274,7 @@ def step(self, closure=None):
             self.to_gpu()  # needed for fairseq pure fp16 training
             self.initialized = True
 
-        #if self.is_paged: self.page_mng.prefetch_all()
+        # if self.is_paged: self.page_mng.prefetch_all()
         for gindex, group in enumerate(self.param_groups):
             for pindex, p in enumerate(group["params"]):
                 if p.grad is None:
@@ -304,7 +291,6 @@ def step(self, closure=None):
             # to sync to make sure all tensors are in the right state
             torch.cuda.synchronize()
 
-
         return loss
 
     def get_config(self, gindex, pindex, group):
@@ -328,9 +314,7 @@ def init_state(self, group, p, gindex, pindex):
         raise NotImplementedError("init_state method needs to be overridden")
 
     def update_step(self, group, p, gindex, pindex):
-        raise NotImplementedError(
-            "The update_step method needs to be overridden"
-        )
+        raise NotImplementedError("The update_step method needs to be overridden")
 
     def get_state_buffer(self, p, dtype=torch.float32):
         if not self.is_paged or p.numel() < 1e5:
@@ -345,12 +329,12 @@ def get_state_buffer(self, p, dtype=torch.float32):
     def prefetch_state(self, p):
         if self.is_paged:
             state = self.state[p]
-            s1 = state['state1']
-            is_paged = getattr(s1, 'is_paged', False)
+            s1 = state["state1"]
+            is_paged = getattr(s1, "is_paged", False)
             if is_paged:
-                F.prefetch_tensor(state['state1'])
-                if 'state2' in state:
-                    F.prefetch_tensor(state['state2'])
+                F.prefetch_tensor(state["state1"])
+                if "state2" in state:
+                    F.prefetch_tensor(state["state2"])
 
 
 class Optimizer2State(Optimizer8bit):
@@ -369,7 +353,7 @@ def __init__(
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
-        is_paged=False
+        is_paged=False,
     ):
         """
         Base 2-state update optimizer class.
@@ -414,13 +398,9 @@ def __init__(
             betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(
-                    f"Invalid beta parameter at index {i}: {betas[i]}"
-                )
+                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
         if not 0.0 <= weight_decay:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults, optim_bits, is_paged)
 
@@ -449,9 +429,7 @@ def init_state(self, group, p, gindex, pindex):
         elif config["optim_bits"] == 8:
             dtype = torch.uint8
         else:
-            raise NotImplementedError(
-                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
-            )
+            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
 
         if p.numel() < config["min_8bit_size"]:
             dtype = torch.float32
@@ -459,21 +437,15 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (
-            dtype == torch.uint8 and p.numel() < 4096
-        ):
+        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
             state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
-                    p.device
-                )
-                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
-                    p.device
-                )
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
 
             state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]
@@ -486,25 +458,13 @@ def init_state(self, group, p, gindex, pindex):
                 blocks = n // 2048
                 blocks += 1 if n % 2048 > 0 else 0
 
-                state["absmax1"] = torch.zeros(
-                    (blocks,), dtype=torch.float32, device=p.device
-                )
-                state["absmax2"] = torch.zeros(
-                    (blocks,), dtype=torch.float32, device=p.device
-                )
+                state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+                state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
             else:
-                state["max1"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
-                state["new_max1"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
-                state["max2"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
-                state["new_max2"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
+                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
 
         if config["percentile_clipping"] < 100:
             state["gnorm_vec"] = torch.zeros((100,), device=p.device)
@@ -524,7 +484,10 @@ def update_step(self, group, p, gindex, pindex):
 
         if config["percentile_clipping"] < 100:
             current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
-                grad, state["gnorm_vec"], step, config["percentile_clipping"]
+                grad,
+                state["gnorm_vec"],
+                step,
+                config["percentile_clipping"],
             )
         else:
             gnorm_scale = 1.0
@@ -568,9 +531,7 @@ def update_step(self, group, p, gindex, pindex):
                 state["new_max2"],
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
-                unorm_vec=state["unorm_vec"]
-                if config["max_unorm"] > 0.0
-                else None,
+                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
             )
 
@@ -615,7 +576,7 @@ def __init__(
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
-        is_paged=False
+        is_paged=False,
     ):
         """
         Base 1-state update optimizer class.
@@ -656,13 +617,9 @@ def __init__(
             raise ValueError(f"Invalid epsilon value: {eps}")
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(
-                    f"Invalid beta parameter at index {i}: {betas[i]}"
-                )
+                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
         if not 0.0 <= weight_decay:
-            raise ValueError(
-                f"Invalid weight_decay value: {weight_decay}"
-            )
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults, optim_bits, is_paged)
 
@@ -691,9 +648,7 @@ def init_state(self, group, p, gindex, pindex):
         elif config["optim_bits"] == 8:
             dtype = torch.uint8
         else:
-            raise NotImplementedError(
-                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
-            )
+            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
 
         if p.numel() < config["min_8bit_size"]:
             dtype = torch.float32
@@ -701,17 +656,13 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (
-            dtype == torch.uint8 and p.numel() < 4096
-        ):
+        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
-                    p.device
-                )
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
 
             state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]
@@ -721,16 +672,10 @@ def init_state(self, group, p, gindex, pindex):
                 blocks = n // 2048
                 blocks += 1 if n % 2048 > 0 else 0
 
-                state["absmax1"] = torch.zeros(
-                    (blocks,), dtype=torch.float32, device=p.device
-                )
+                state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
             else:
-                state["max1"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
-                state["new_max1"] = torch.zeros(
-                    (1,), dtype=torch.float32, device=p.device
-                )
+                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
 
         if config["percentile_clipping"] < 100:
             state["gnorm_vec"] = torch.zeros((100,), device=p.device)
@@ -750,7 +695,10 @@ def update_step(self, group, p, gindex, pindex):
 
         if config["percentile_clipping"] < 100:
             current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
-                grad, state["gnorm_vec"], step, config["percentile_clipping"]
+                grad,
+                state["gnorm_vec"],
+                step,
+                config["percentile_clipping"],
             )
         else:
             gnorm_scale = 1.0
@@ -766,7 +714,7 @@ def update_step(self, group, p, gindex, pindex):
                 step,
                 config["lr"],
                 None,
-                config['betas'][1],
+                config["betas"][1],
                 config["weight_decay"],
                 gnorm_scale,
                 state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index ac371a66f..659617654 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -51,9 +51,7 @@ def __init__(
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         """
         if alpha == 0:
-            raise NotImplementedError(
-                "RMSprop with alpha==0.0 is not supported!"
-            )
+            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
         if centered:
             raise NotImplementedError("Centered RMSprop is not supported!")
         super().__init__(
@@ -116,9 +114,7 @@ def __init__(
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         """
         if alpha == 0:
-            raise NotImplementedError(
-                "RMSprop with alpha==0.0 is not supported!"
-            )
+            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
         if centered:
             raise NotImplementedError("Centered RMSprop is not supported!")
         super().__init__(
@@ -182,9 +178,7 @@ def __init__(
         """
 
         if alpha == 0:
-            raise NotImplementedError(
-                "RMSprop with alpha==0.0 is not supported!"
-            )
+            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
         if centered:
             raise NotImplementedError("Centered RMSprop is not supported!")
         super().__init__(
diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py
index 7d869e39a..b194b8777 100644
--- a/bitsandbytes/research/autograd/_functions.py
+++ b/bitsandbytes/research/autograd/_functions.py
@@ -195,9 +195,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B00
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
 
         # 1. Quantize A
         # 2. Quantize B
@@ -216,9 +216,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B00
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A.to(torch.float16), threshold=state.threshold
-        )
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
 
         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
@@ -234,14 +232,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B00
                     # we also need to convert it to the turing/ampere format
                     state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
-            #print('A shape', A.shape)
+            # print('A shape', A.shape)
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None
 
         # 2. Quantize B
         if state.has_fp16_weights:
-            #print('B shape', B.shape)
+            # print('B shape', B.shape)
             has_grad = True if (getattr(B, "grad", None) is not None) else False
             is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
             if is_transposed:
@@ -272,12 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B00
             # else:
             #    state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (
-                (outliers * state.SCB.view(-1, 1) / 127.0)
-                .t()
-                .contiguous()
-                .to(A.dtype)
-            )
+            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
@@ -320,14 +313,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B00
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
-
-        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
         return clone_func(output.view(output_shape))
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA, A = ctx.tensors
@@ -342,9 +334,7 @@ def backward(ctx, grad_output):
 
         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(
-                -1, grad_output.shape[-1]
-            ).contiguous()
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
 
@@ -357,25 +347,24 @@ def backward(ctx, grad_output):
             if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(
-                        state.CBt, to_order=formatB, transpose=True
-                    )
+                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                 # print('back B shape', state.CxBt.shape)
                 # print('back grad shape', C32grad.shape)
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
             elif state.CB is not None:
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
-                raise Exception('State must contain either CBt or CB matrix for backward')
+                raise Exception("State must contain either CBt or CB matrix for backward")
 
         return grad_A, grad_B, None, grad_bias, None
 
+
 def get_block_sizes(input_matrix, weight_matrix):
     input_features = input_matrix.shape[-1]
-    output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
+    output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]
     array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
     bsz, bsz2 = 1024, 1024
     for i, k in enumerate(array):
@@ -399,7 +388,8 @@ def matmul_fp8_global(
     bsz: int = -1,
     bsz2: int = -1,
 ):
-    if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
+    if bsz == -1 or bsz2 == -1:
+        bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
 
@@ -412,7 +402,8 @@ def matmul_fp8_mixed(
     bsz: int = -1,
     bsz2: int = -1,
 ):
-    if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
+    if bsz == -1 or bsz2 == -1:
+        bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
 
@@ -422,7 +413,7 @@ def switchback_bnb(
     out: Optional[torch.Tensor] = None,
     state: Optional[MatmulLtState] = None,
     threshold=0.0,
-    bias=None
+    bias=None,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py
index 7fca34d23..57c0f3358 100644
--- a/bitsandbytes/research/nn/modules.py
+++ b/bitsandbytes/research/nn/modules.py
@@ -28,12 +28,20 @@ def forward(self, x: torch.Tensor):
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
+        out = bnb.research.matmul_fp8_mixed(
+            x,
+            self.weight.t(),
+            fw_code=self.fw_code,
+            bw_code=self.bw_code,
+            bsz=self.bsz,
+            bsz2=self.bsz2,
+        )
         if self.bias is not None:
             out += self.bias
 
         return out
 
+
 class LinearFP8Global(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
@@ -54,7 +62,14 @@ def forward(self, x: torch.Tensor):
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
+        out = bnb.matmul_fp8_global(
+            x,
+            self.weight.t(),
+            fw_code=self.fw_code,
+            bw_code=self.bw_code,
+            bsz=self.bsz,
+            bsz2=self.bsz2,
+        )
         if self.bias is not None:
             out += self.bias
 
diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py
index 3d7529852..26eab84f2 100644
--- a/bitsandbytes/triton/dequantize_rowwise.py
+++ b/bitsandbytes/triton/dequantize_rowwise.py
@@ -5,9 +5,10 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
-else:
 
+    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
+        return None
+else:
     import triton
     import triton.language as tl
 
@@ -15,21 +16,21 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
 
     # TODO: autotune this better.
     @triton.autotune(
-            configs=[
-                triton.Config({}, num_stages=1, num_warps=8),
-                triton.Config({}, num_stages=2, num_warps=8),
-                triton.Config({}, num_stages=4, num_warps=8),
-                triton.Config({}, num_stages=8, num_warps=8),
-                triton.Config({}, num_stages=1),
-                triton.Config({}, num_stages=2),
-                triton.Config({}, num_stages=4),
-                triton.Config({}, num_stages=8),
-                triton.Config({}, num_warps=1),
-                triton.Config({}, num_warps=2),
-                triton.Config({}, num_warps=4),
-                triton.Config({}, num_warps=8),
-            ],
-            key=['n_elements']
+        configs=[
+            triton.Config({}, num_stages=1, num_warps=8),
+            triton.Config({}, num_stages=2, num_warps=8),
+            triton.Config({}, num_stages=4, num_warps=8),
+            triton.Config({}, num_stages=8, num_warps=8),
+            triton.Config({}, num_stages=1),
+            triton.Config({}, num_stages=2),
+            triton.Config({}, num_stages=4),
+            triton.Config({}, num_stages=8),
+            triton.Config({}, num_warps=1),
+            triton.Config({}, num_warps=2),
+            triton.Config({}, num_warps=4),
+            triton.Config({}, num_warps=8),
+        ],
+        key=["n_elements"],
     )
     @triton.jit
     def _dequantize_rowwise(
@@ -51,7 +52,6 @@ def _dequantize_rowwise(
         output = max_val * x * inv_127
         tl.store(output_ptr + offsets, output, mask=row_mask)
 
-
     def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
         output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
 
@@ -60,5 +60,5 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
         grid = lambda meta: (x.shape[0],)
-        _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
+        _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
         return output
diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py
index dc3047d7e..583371d91 100644
--- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py
+++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py
@@ -3,14 +3,14 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
-else:
 
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
+        return None
+else:
     import triton
     import triton.language as tl
     from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
 
-
     # This is a matmul kernel based on triton.ops.matmul
     # It is modified to support rowwise quantized input and global quantized weight
     # Its purpose is fused matmul then dequantize
@@ -27,58 +27,83 @@ def get_configs_io_bound():
                     for block_n in [32, 64, 128, 256]:
                         num_warps = 2 if block_n <= 64 else 4
                         configs.append(
-                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                          num_stages=num_stages, num_warps=num_warps))
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                            ),
+                        )
                         # split_k
                         for split_k in [2, 4, 8, 16]:
-                            configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                            configs.append(
+                                triton.Config(
+                                    {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                    num_stages=num_stages,
+                                    num_warps=num_warps,
+                                    pre_hook=init_to_zero("C"),
+                                ),
+                            )
         return configs
 
-
     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+    )
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
         },
     )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
     @triton.jit
-    def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
-                stride_am, stride_ak,
-                stride_bk, stride_bn,
-                stride_cm, stride_cn,
-                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                ACC_TYPE: tl.constexpr
-                ):
+    def _int8_matmul_mixed_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor: tl.constexpr,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
         pid_z = tl.program_id(1)
@@ -115,13 +140,13 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk
 
-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)
 
         # conditionally add bias
@@ -137,10 +162,9 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
         else:
             tl.atomic_add(C, acc, mask=mask)
 
-
     def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)
         has_bias = 0 if bias is None else 1
         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
@@ -154,12 +178,28 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         # allocates output
         c = torch.empty((M, N), device=device, dtype=torch.float16)
         # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch int8_matmul_mixed_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                        a.stride(0), a.stride(1),
-                        b.stride(0), b.stride(1),
-                        c.stride(0), c.stride(1),
-                        GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_mixed_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
         return c
diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
index 4881e1468..e3d192ded 100644
--- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
+++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
@@ -3,7 +3,9 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -17,7 +19,6 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
     def init_to_zero(name):
         return lambda nargs: nargs[name].zero_()
 
-
     def get_configs_io_bound():
         configs = []
         for num_stages in [2, 3, 4, 5, 6]:
@@ -26,58 +27,83 @@ def get_configs_io_bound():
                     for block_n in [32, 64, 128, 256]:
                         num_warps = 2 if block_n <= 64 else 4
                         configs.append(
-                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                          num_stages=num_stages, num_warps=num_warps))
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                            ),
+                        )
                         # split_k
                         for split_k in [2, 4, 8, 16]:
-                            configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                            configs.append(
+                                triton.Config(
+                                    {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                    num_stages=num_stages,
+                                    num_warps=num_warps,
+                                    pre_hook=init_to_zero("C"),
+                                ),
+                            )
         return configs
 
-
     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+    )
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
         },
     )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
     @triton.jit
-    def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
-                stride_am, stride_ak,
-                stride_bk, stride_bn,
-                stride_cm, stride_cn,
-                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                ACC_TYPE: tl.constexpr
-                ):
+    def _int8_matmul_rowwise_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
         pid_z = tl.program_id(1)
@@ -114,13 +140,13 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M,
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk
 
-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)
 
         if has_bias:
@@ -135,9 +161,8 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M,
         else:
             tl.atomic_add(C, acc, mask=mask)
 
-
     def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)
 
         has_bias = 0 if bias is None else 1
 
@@ -154,12 +179,28 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
         # allocates output
         c = torch.empty((M, N), device=device, dtype=torch.float16)
         # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch int8_matmul_rowwise_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                        a.stride(0), a.stride(1),
-                        b.stride(0), b.stride(1),
-                        c.stride(0), c.stride(1),
-                        GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_rowwise_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
         return c
diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py
index e7961cf53..b8eeffd0c 100644
--- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py
+++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py
@@ -5,9 +5,10 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
-else:
 
+    def quantize_columnwise_and_transpose(x: torch.Tensor):
+        return None
+else:
     import triton
     import triton.language as tl
 
@@ -15,23 +16,23 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): return None
 
     # TODO: autotune this better.
     @triton.autotune(
-            configs=[
-                triton.Config({}, num_stages=1),
-                triton.Config({}, num_stages=2),
-                triton.Config({}, num_stages=4),
-                triton.Config({}, num_stages=8),
-                triton.Config({}, num_stages=16),
-                triton.Config({}, num_stages=1, num_warps=8),
-                triton.Config({}, num_stages=2, num_warps=8),
-                triton.Config({}, num_stages=4, num_warps=8),
-                triton.Config({}, num_stages=8, num_warps=8),
-                triton.Config({}, num_stages=16, num_warps=8),
-                triton.Config({}, num_warps=1),
-                triton.Config({}, num_warps=2),
-                triton.Config({}, num_warps=4),
-                triton.Config({}, num_warps=8),
-            ],
-            key=['n_elements']
+        configs=[
+            triton.Config({}, num_stages=1),
+            triton.Config({}, num_stages=2),
+            triton.Config({}, num_stages=4),
+            triton.Config({}, num_stages=8),
+            triton.Config({}, num_stages=16),
+            triton.Config({}, num_stages=1, num_warps=8),
+            triton.Config({}, num_stages=2, num_warps=8),
+            triton.Config({}, num_stages=4, num_warps=8),
+            triton.Config({}, num_stages=8, num_warps=8),
+            triton.Config({}, num_stages=16, num_warps=8),
+            triton.Config({}, num_warps=1),
+            triton.Config({}, num_warps=2),
+            triton.Config({}, num_warps=4),
+            triton.Config({}, num_warps=8),
+        ],
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_columnwise_and_transpose(
@@ -39,7 +40,8 @@ def _quantize_columnwise_and_transpose(
         output_ptr,
         output_maxs,
         n_elements,
-        M : tl.constexpr, N : tl.constexpr,
+        M: tl.constexpr,
+        N: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
     ):
@@ -47,12 +49,12 @@ def _quantize_columnwise_and_transpose(
         block_start = pid
         p2_arange = tl.arange(0, P2)
         p2_arange_mask = p2_arange < M
-        arange =  p2_arange * N
+        arange = p2_arange * N
         offsets = block_start + arange
         x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
 
         new_start = pid * M
         new_offsets = new_start + p2_arange
@@ -68,6 +70,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor):
 
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
-        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
         return output, output_maxs
diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py
index 5cf194744..f35bdd304 100644
--- a/bitsandbytes/triton/quantize_global.py
+++ b/bitsandbytes/triton/quantize_global.py
@@ -1,24 +1,25 @@
-
 import torch
 
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def quantize_global_transpose(input): return None
-    def quantize_global(x: torch.Tensor): return None
-else:
 
+    def quantize_global_transpose(input):
+        return None
+
+    def quantize_global(x: torch.Tensor):
+        return None
+else:
     import triton
     import triton.language as tl
 
     # global quantize
     @triton.autotune(
-            configs=[
-                triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
-                triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
-
-            ],
-            key=['n_elements']
+        configs=[
+            triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
+        ],
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_global(
@@ -34,35 +35,43 @@ def _quantize_global(
         mask = offsets < n_elements
         x = tl.load(x_ptr + offsets, mask=mask)
         absmax_inv = tl.load(absmax_inv_ptr)
-        output = tl.libdevice.llrint(127. * (x * absmax_inv))
+        output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
         tl.store(output_ptr + offsets, output, mask=mask)
 
     def quantize_global(x: torch.Tensor):
         absmax = x.abs().max().unsqueeze(0)
-        absmax_inv = 1./ absmax
-        output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
+        absmax_inv = 1.0 / absmax
+        output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
-        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         _quantize_global[grid](x, absmax_inv, output, n_elements)
         return output, absmax
 
-
     # global quantize and transpose
     @triton.autotune(
-            configs=[
-                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
-                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
-
-                # ...
-            ],
-            key=['M', 'N']
+        configs=[
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
+            # ...
+        ],
+        key=["M", "N"],
     )
     @triton.jit
-    def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
-                          BLOCK_M : tl.constexpr,
-                          BLOCK_N : tl.constexpr,
-                          GROUP_M : tl.constexpr):
+    def _quantize_global_transpose(
+        A,
+        absmax_inv_ptr,
+        B,
+        stride_am,
+        stride_an,
+        stride_bn,
+        stride_bm,
+        M,
+        N,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        GROUP_M: tl.constexpr,
+    ):
         pid = tl.program_id(0)
         grid_m = (M + BLOCK_M - 1) // BLOCK_M
         grid_n = (N + BLOCK_N - 1) // BLOCK_N
@@ -86,20 +95,30 @@ def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, strid
         B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
         mask = (rm < M)[:, None] & (rn < N)[None, :]
 
-        output = tl.libdevice.llrint(127. * (a * absmax_inv))
+        output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
 
         tl.store(B, output, mask=mask)
 
     def quantize_global_transpose(input):
         absmax = input.abs().max().unsqueeze(0)
-        absmax_inv = 1./ absmax
+        absmax_inv = 1.0 / absmax
         M, N = input.shape
-        out = torch.empty(N, M, device='cuda', dtype=torch.int8)
+        out = torch.empty(N, M, device="cuda", dtype=torch.int8)
 
         assert out.size(0) == N and out.size(1) == M
         assert input.stride(0) == 1 or input.stride(1) == 1
         assert out.stride(0) == 1 or out.stride(1) == 1
 
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
-        _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+        _quantize_global_transpose[grid](
+            input,
+            absmax_inv,
+            out,
+            input.stride(0),
+            input.stride(1),
+            out.stride(0),
+            out.stride(1),
+            M,
+            N,
+        )
         return out, absmax
diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py
index 078f4aa2d..f92ace02c 100644
--- a/bitsandbytes/triton/quantize_rowwise.py
+++ b/bitsandbytes/triton/quantize_rowwise.py
@@ -5,9 +5,10 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def quantize_rowwise(x: torch.Tensor): return None
-else:
 
+    def quantize_rowwise(x: torch.Tensor):
+        return None
+else:
     import triton
     import triton.language as tl
 
@@ -15,21 +16,21 @@ def quantize_rowwise(x: torch.Tensor): return None
 
     # TODO: autotune this better.
     @triton.autotune(
-            configs=[
-                triton.Config({}, num_stages=1, num_warps=8),
-                triton.Config({}, num_stages=2, num_warps=8),
-                triton.Config({}, num_stages=4, num_warps=8),
-                triton.Config({}, num_stages=8, num_warps=8),
-                triton.Config({}, num_stages=1),
-                triton.Config({}, num_stages=2),
-                triton.Config({}, num_stages=4),
-                triton.Config({}, num_stages=8),
-                triton.Config({}, num_warps=1),
-                triton.Config({}, num_warps=2),
-                triton.Config({}, num_warps=4),
-                triton.Config({}, num_warps=8),
-            ],
-            key=['n_elements']
+        configs=[
+            triton.Config({}, num_stages=1, num_warps=8),
+            triton.Config({}, num_stages=2, num_warps=8),
+            triton.Config({}, num_stages=4, num_warps=8),
+            triton.Config({}, num_stages=8, num_warps=8),
+            triton.Config({}, num_stages=1),
+            triton.Config({}, num_stages=2),
+            triton.Config({}, num_stages=4),
+            triton.Config({}, num_stages=8),
+            triton.Config({}, num_warps=1),
+            triton.Config({}, num_warps=2),
+            triton.Config({}, num_warps=4),
+            triton.Config({}, num_warps=8),
+        ],
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_rowwise(
@@ -49,7 +50,7 @@ def _quantize_rowwise(
 
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
         tl.store(output_ptr + offsets, output, mask=row_mask)
         tl.store(output_maxs + pid, max_val)
 
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 0582f7fc0..48c7fc82d 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -30,7 +30,7 @@ def outlier_hook(module, input):
             # (1) zscore test of std of hidden dimension
             outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
             # (2) magnitude > 6 test
-            dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
+            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
             outlier_idx2 = torch.where(dims > 0)[0]
             outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
             tracer.hvalue2outlier_idx[hvalue] = outlier_idx
@@ -59,14 +59,14 @@ def initialize(self, model):
                 self.hooks.append(m.register_forward_pre_hook(outlier_hook))
 
     def is_initialized(self):
-        return getattr(self, 'initialized', False)
+        return getattr(self, "initialized", False)
 
     def get_hvalue(self, weight):
         return weight.data.storage().data_ptr()
 
     def get_outliers(self, weight):
         if not self.is_initialized():
-            print('Outlier tracer is not initialized...')
+            print("Outlier tracer is not initialized...")
             return None
         hvalue = self.get_hvalue(weight)
         if hvalue in self.hvalue2outlier_idx:
@@ -80,6 +80,7 @@ def get_instance(cls):
             cls._instance = cls.__new__(cls)
         return cls._instance
 
+
 def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
     if rdm:
         return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
     m = weight.mean(reduction_dim)
     mm = m.mean()
     mstd = m.std()
-    zm = (m-mm)/mstd
+    zm = (m - mm) / mstd
 
     std = weight.std(reduction_dim)
     stdm = std.mean()
     stdstd = std.std()
 
-    zstd = (std-stdm)/stdstd
+    zstd = (std - stdm) / stdstd
 
     if topk is not None:
         val, idx = torch.topk(std.abs(), k=topk, dim=0)
@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
 
 def execute_and_return(command_string: str) -> Tuple[str, str]:
     def _decode(subprocess_err_out_tuple):
-        return tuple(
-            to_decode.decode("UTF-8").strip()
-            for to_decode in subprocess_err_out_tuple
-        )
+        return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple)
 
     def execute_and_return_decoded_std_streams(command_string):
         return _decode(
@@ -116,14 +114,13 @@ def execute_and_return_decoded_std_streams(command_string):
                 shlex.split(command_string),
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
-            ).communicate()
+            ).communicate(),
         )
 
     std_out, std_err = execute_and_return_decoded_std_streams(command_string)
     return std_out, std_err
 
 
-
 def replace_linear(
     model,
     linear_replacement,
@@ -163,8 +160,9 @@ def replace_linear(
                 model._modules[name].bias = old_module.bias
 
             if post_processing_function is not None:
-               func = getattr(module, post_processing_function, None)
-               if func is not None: func(module)
+                func = getattr(module, post_processing_function, None)
+                if func is not None:
+                    func(module)
     return model
 
 
@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
     A torch tensor containing the packed data.
     """
     json_str = json.dumps(source_dict)
-    json_bytes = json_str.encode('utf-8')
+    json_bytes = json_str.encode("utf-8")
     tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8)
 
     return tensor_data
@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
     A Python dictionary containing the unpacked data.
     """
     json_bytes = bytes(tensor_data.cpu().numpy())
-    json_str = json_bytes.decode('utf-8')
+    json_str = json_bytes.decode("utf-8")
     unpacked_dict = json.loads(json_str)
 
     return unpacked_dict
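
pack_dict_to_tensor and unpack_tensor_to_dict, reformatted above, are a simple JSON round-trip through a uint8 tensor. A short usage sketch (the example dict is arbitrary):

import torch
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

meta = {"quant_type": "nf4", "blocksize": 64}  # arbitrary example payload
packed = pack_dict_to_tensor(meta)             # uint8 tensor holding the UTF-8 JSON bytes
assert packed.dtype == torch.uint8
assert unpack_tensor_to_dict(packed) == meta   # lossless round-trip
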
diff --git a/check_bnb_install.py b/check_bnb_install.py
index 5a7f74f89..7a9dc93fc 100644
--- a/check_bnb_install.py
+++ b/check_bnb_install.py
@@ -2,14 +2,14 @@
 
 import bitsandbytes as bnb
 
-p = torch.nn.Parameter(torch.rand(10,10).cuda())
-a = torch.rand(10,10).cuda()
+p = torch.nn.Parameter(torch.rand(10, 10).cuda())
+a = torch.rand(10, 10).cuda()
 
 p1 = p.data.sum().item()
 
 adam = bnb.optim.Adam([p])
 
-out = a*p
+out = a * p
 loss = out.sum()
 loss.backward()
 adam.step()
@@ -17,5 +17,5 @@
 p2 = p.data.sum().item()
 
 assert p1 != p2
-print('SUCCESS!')
-print('Installation was successful!')
+print("SUCCESS!")
+print("Installation was successful!")
diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py
index c89ba8d11..2d4c77952 100644
--- a/examples/int8_inference_huggingface.py
+++ b/examples/int8_inference_huggingface.py
@@ -2,23 +2,18 @@
 from transformers import LlamaForCausalLM, LlamaTokenizer
 
 MAX_NEW_TOKENS = 128
-model_name = 'meta-llama/Llama-2-7b-hf'
+model_name = "meta-llama/Llama-2-7b-hf"
 
-text = 'Hamburg is in which country?\n'
+text = "Hamburg is in which country?\n"
 tokenizer = LlamaTokenizer.from_pretrained(model_name)
 input_ids = tokenizer(text, return_tensors="pt").input_ids
 
-max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
+max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
 
 n_gpus = torch.cuda.device_count()
 max_memory = {i: max_memory for i in range(n_gpus)}
 
-model = LlamaForCausalLM.from_pretrained(
-  model_name,
-  device_map='auto',
-  load_in_8bit=True,
-  max_memory=max_memory
-)
+model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)
 
 generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
 print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
diff --git a/install_cuda.py b/install_cuda.py
index b41b33b39..9e426cbd7 100644
--- a/install_cuda.py
+++ b/install_cuda.py
@@ -19,6 +19,7 @@
     "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
 }
 
+
 def install_cuda(version, base_path, download_path):
     formatted_version = f"{version[:-1]}.{version[-1]}"
     folder = f"cuda-{formatted_version}"
@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
         subprocess.run(["rm", "-rf", install_path], check=True)
 
     url = cuda_versions[version]
-    filename = url.split('/')[-1]
+    filename = url.split("/")[-1]
     filepath = os.path.join(download_path, filename)
 
     if not os.path.exists(filepath):
@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
     # Install CUDA
     print(f"Installing CUDA version {version}...")
     install_command = [
-        "bash", filepath,
-        "--no-drm", "--no-man-page", "--override",
-        "--toolkitpath=" + install_path, "--toolkit", "--silent"
+        "bash",
+        filepath,
+        "--no-drm",
+        "--no-man-page",
+        "--override",
+        "--toolkitpath=" + install_path,
+        "--toolkit",
+        "--silent",
     ]
 
     print(f"Running command: {' '.join(install_command)}")
@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
 
     print(f"CUDA version {version} installed at {install_path}")
 
+
 def main():
     user_base_path = os.path.expanduser("~/cuda")
     system_base_path = "/usr/local/cuda"
@@ -93,5 +100,6 @@ def main():
         print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
         sys.exit(1)
 
+
 if __name__ == "__main__":
     main()
diff --git a/pyproject.toml b/pyproject.toml
index f74750720..609ff84fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,10 @@ src = [
     "tests",
     "benchmarking"
 ]
+target-version = "py38"
+line-length = 119
+
+[tool.ruff.lint]
 select = [
     "B",    # bugbear: security warnings
     "E",    # pycodestyle
@@ -17,7 +21,6 @@ select = [
     "UP",   # alert you when better syntax is available in your python version
     "RUF",  # the ruff developer's own rules
 ]
-target-version = "py38"
 ignore = [
     "B007",  # Loop control variable not used within the loop body (TODO: enable)
     "B028",  # Warning without stacklevel (TODO: enable)
@@ -30,7 +33,7 @@ ignore = [
 ]
 ignore-init-module-imports = true  # allow to expose in __init__.py via imports
 
-[tool.ruff.extend-per-file-ignores]
+[tool.ruff.lint.extend-per-file-ignores]
 "**/__init__.py" = ["F401"]  # allow unused imports in __init__.py
 "{benchmarking,tests}/**/*.py" = [
     "B007",
@@ -42,7 +45,7 @@ ignore-init-module-imports = true  # allow to expose in __init__.py via imports
     "UP030",
 ]
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 combine-as-imports = true
 detect-same-package = true
 force-sort-within-sections = true
diff --git a/scripts/stale.py b/scripts/stale.py
index 613f5b7cb..a65652aeb 100644
--- a/scripts/stale.py
+++ b/scripts/stale.py
@@ -15,6 +15,7 @@
 Script to close stale issue. Taken in part from the AllenNLP repository.
 https://github.com/allenai/allennlp.
 """
+
 from datetime import datetime as dt, timezone
 import os
 
@@ -50,7 +51,7 @@ def main():
             issue.create_comment(
                 "This issue has been automatically marked as stale because it has not had "
                 "recent activity. If you think this still needs to be addressed "
-                "please comment on this thread.\n\n"
+                "please comment on this thread.\n\n",
             )
 
 
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index d01e5e9db..9da665a2d 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -20,7 +20,11 @@
 @pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
 @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
+    ids=["func=bmm", "func=matmul"],
+)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
     dim3 = dim3 - (dim3 % 16)
     dim4 = dim4 - (dim4 % 16)
     for i in range(25):
-
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
             dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
             dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
             A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
             B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
-            target = torch.randn(
-                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
-            )
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
             torch.nn.init.xavier_uniform_(B)
 
             if not transpose[0] and not transpose[1]:
@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(
-                    out_torch, target
-                ).mean()
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_close(
-                    gradA1, gradA2, atol=0.015, rtol=0.1
-                )
+                torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_close(
-                    gradB1, gradB2, atol=0.18, rtol=0.3
-                )
+                torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
 
         # batched matrix multiply
         if funcs[0] in [torch.bmm, torch.matmul]:
@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
             assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_close(
-                out_bnb, out_torch, atol=0.027, rtol=0.2
-            )
+            torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
 
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(
-                    out_torch, target
-                ).mean()
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_close(
-                    gradA1, gradA2, atol=0.015, rtol=0.1
-                )
+                torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(
-                    out_torch, target
-                ).mean()
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_close(
-                    gradA1, gradA2, atol=0.015, rtol=0.1
-                )
+                torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
 @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
 @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
-@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
+    ids=["func=matmul", "func=switchback_bnb"],
+)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
 @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(
-    dim1,
-    dim2,
-    dim3,
-    dim4,
-    funcs,
-    dtype,
-    req_grad,
-    transpose,
-    decomp,
-    has_fp16_weights,
-    has_bias
-):
+def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
@@ -245,18 +222,13 @@ def test_matmullt(
         req_grad[2] = False
 
     for i in range(3):
-
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(
-                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
-            )
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(
-                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
-            )
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
                 device="cuda",
@@ -266,7 +238,7 @@ def test_matmullt(
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -311,9 +283,7 @@ def test_matmullt(
                 if any(req_grad):
                     out_bnb.data.copy_(out_torch)
                     torch.cuda.synchronize()
-                    loss_bnb = torch.nn.functional.mse_loss(
-                        out_bnb, target
-                    ).mean()
+                    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                     loss_bnb.backward()
                     gradA1 = A.grad
                     gradB1 = B.grad
@@ -323,9 +293,7 @@ def test_matmullt(
                         gradBias1 = bias.grad
                         bias.grad = None
 
-                    loss_torch = torch.nn.functional.mse_loss(
-                        out_torch, target
-                    ).mean()
+                    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                     loss_torch.backward()
                     gradA2 = A.grad
                     gradB2 = B.grad
@@ -336,9 +304,7 @@ def test_matmullt(
                         bias.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_close(
-                        gradA1, gradA2, atol=0.015, rtol=0.1
-                    )
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
                 if req_grad[1]:
                     n = gradB1.numel()
                     if dim2 > 0:
@@ -352,9 +318,7 @@ def test_matmullt(
                     assert (idx == 0).sum().item() <= n * 0.1
                     idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                     assert (idx == 0).sum().item() <= n * 0.02
-                    torch.testing.assert_close(
-                        gradB1, gradB2, atol=0.18, rtol=0.3
-                    )
+                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
 
                 if req_grad[2]:
                     torch.testing.assert_close(gradBias1, gradBias2)
@@ -370,8 +334,20 @@ def test_matmullt(
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
-def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
+@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
+def test_matmul_4bit(
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    has_bias,
+    compress_statistics,
+    quant_type,
+):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
 
-            B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
+            B2, quant_state = bnb.functional.quantize_4bit(
+                B,
+                compress_statistics=compress_statistics,
+                quant_type=quant_type,
+            )
 
             if not transpose[0] and transpose[1]:
                 out_torch = funcs[0](A, B.t())
@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             if n > 0:
                 assert err < 0.115
 
-                #assert err < 0.20
+                # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
                 torch.cuda.synchronize()
@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                     gradBias1 = bias.grad
                     bias.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                     bias.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
 
                 if req_grad[2]:
                     torch.testing.assert_close(gradBias1, gradBias2)
@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
 @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
-def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
+    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
+)
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     req_grad = list(req_grad)
@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             err = torch.abs(out_bnb - out_torch).float().mean().item()
             if n > 0:
                 assert err < 0.115
-                #assert err < 0.20
+                # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
                 torch.cuda.synchronize()
@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
 
                 if req_grad[1]:
                     n = gradB1.numel()
@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                     assert (idx == 0).sum().item() <= n * 0.1
                     idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                     assert (idx == 0).sum().item() <= n * 0.02
-                    grad_err = (gradB1-gradB2).abs().mean()
+                    grad_err = (gradB1 - gradB2).abs().mean()
                     assert grad_err.item() < 0.003
-                    torch.testing.assert_close(
-                        gradB1, gradB2, atol=0.18, rtol=0.3
-                    )
+                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index cb0b38fdd..fc79a54b0 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
 
 def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert (
-        get_cuda_bnb_library_path(cuda111_noblas_spec).stem
-        == "libbitsandbytes_cuda111_nocublaslt"
-    )
+    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d4f65755f..b9f1a6ead 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -19,9 +19,7 @@
     id_formatter,
 )
 
-torch.set_printoptions(
-    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
-)
+torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
 k = 20
 
 
@@ -98,9 +96,7 @@ def teardown():
     pass
 
 
-@pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
-)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
 def test_estimate_quantiles(dtype):
     A = torch.rand(1024, 1024, device="cuda")
     A = A.to(dtype)
@@ -136,7 +132,6 @@ def test_quantile_quantization():
         assert diff < 0.001
 
 
-
 def test_dynamic_quantization():
     diffs = []
     reldiffs = []
@@ -149,8 +144,8 @@ def test_dynamic_quantization():
         diffs.append(diff.mean().item())
         reldiffs.append(reldiff.mean().item())
         assert diff.mean().item() < 0.0135
-    print(sum(diffs)/len(diffs))
-    print(sum(reldiffs)/len(reldiffs))
+    print(sum(diffs) / len(diffs))
+    print(sum(reldiffs) / len(reldiffs))
 
     for i in range(100):
         A1 = torch.rand(1024, 1024, device="cuda")
@@ -161,13 +156,12 @@ def test_dynamic_quantization():
         assert diff < 0.004
 
 
-
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
 @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
 @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
 def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
-    #print('')
+    # print('')
     diffs = []
     reldiffs = []
     for i in range(100):
@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
         reldiff = diff / torch.abs(A1.float() + 1e-8)
         diffs.append(diff.mean().item())
         reldiffs.append(reldiff.mean().item())
-    abserr = sum(diffs)/len(diffs)
-    relerr = sum(reldiffs)/len(reldiffs)
-    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
-    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
+    abserr = sum(diffs) / len(diffs)
+    relerr = sum(reldiffs) / len(reldiffs)
+    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
+    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
     assert abserr < 0.011
     assert relerr < 0.018
     assert A2.dtype == dtype
@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
         reldiff = diff / torch.abs(A1.float() + 1e-8)
         diffs.append(diff.mean().item())
         reldiffs.append(reldiff.mean().item())
-        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
-    abserr = sum(diffs)/len(diffs)
-    relerr = sum(reldiffs)/len(reldiffs)
+        # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
+    abserr = sum(diffs) / len(diffs)
+    relerr = sum(reldiffs) / len(reldiffs)
     if signed:
         assert abserr < 0.0035
         assert relerr < 0.015
@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
         assert abserr < 0.00175
         assert relerr < 0.012
     assert A2.dtype == dtype
-    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
-    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
-
+    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
 
 
-@pytest.mark.parametrize(
-    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
-)
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
 def test_percentile_clipping(gtype):
     gnorm_vec1 = torch.zeros(100, device="cuda")
     gnorm_vec2 = torch.zeros(100, device="cuda")
@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
     for i in range(k):
         step += 1
         g = torch.randn(n, n, dtype=gtype, device="cuda")
-        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
-            g, gnorm_vec2, step, percentile=percentile
-        )
+        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
         assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
 
         gnorm2 = torch.norm(g.float())
@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
     dim2 = dim2 - (dim2 % 32)
     errors = []
     relerrors = []
-    #print("")
+    # print("")
     for i in range(5):
         if batched:
             A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
             B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
             maxA, Ac = quant_methods[0](A, 1)
             maxB, Bc = quant_methods[1](B, 0)
-        torch.testing.assert_close(
-            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
-        )
+        torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
         if batched:
             out2 = torch.bmm(A, B)
             C = torch.bmm(Ac.float(), Bc.float())
@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
         relerr = err / torch.abs(out2)
         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
-    #print(mean(errors))
-    #print(mean(relerrors))
+    # print(mean(errors))
+    # print(mean(relerrors))
 
 
 def test_stable_embedding():
@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
     batch_dim = batch_dim - (batch_dim % 16)
     seq_dim = seq_dim - (seq_dim % 16)
     for i in range(k):
-        shapeA = (
-            (batch_dim, hidden_dim)
-            if not transpose[0]
-            else (hidden_dim, batch_dim)
-        )
-        shapeB = (
-            (32 * random.randint(1, 4), hidden_dim)
-            if transpose[1]
-            else (hidden_dim, 32 * random.randint(1, 4))
-        )
+        shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
         A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
         B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
         if not transpose[0] and not transpose[1]:
@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
 
     for i in range(k):
         shapeA = (batch_dim, seq_dim, hidden_dim)
-        shapeB = (
-            (32 * random.randint(1, 4), hidden_dim)
-            if transpose[1]
-            else (hidden_dim, 32 * random.randint(1, 4))
-        )
+        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
         A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
         B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
         if not transpose[0] and not transpose[1]:
@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
     hidden_dim = hidden_dim - (hidden_dim % 32)
     batch_dim = batch_dim - (batch_dim % 2)
     for i in range(25):
-        A = torch.randint(
-            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
-        ).to(torch.int8)
-        B = torch.randint(
-            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
-        ).to(torch.int8)
+        A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
+        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
         out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
-        iout = torch.empty(
-            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
-        )
+        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
         out = F.igemm(A, B, out=iout)
 
         torch.testing.assert_close(out.float(), out2)
@@ -444,9 +413,7 @@ def min_max(x):
     errs2 = []
     relerrs2 = []
     for i in range(k):
-        A = torch.normal(
-            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
-        )
+        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
         if transpose:
             B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
         else:
@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
             out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
             out = F.igemm(A.permute([0, 2, 1]), B)
         elif transpose[0] and transpose[1]:
-            out2 = torch.bmm(
-                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
-            )
+            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
             out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
         torch.testing.assert_close(out.float(), out2.float())
 
@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
         qA, SA = F.vectorwise_quant(A, dim=0)
         A1 = F.vectorwise_dequant(qA, SA)
         n = A1.numel()
-        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))
+        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
 
 
 @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
     if dims == 2:
         A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
     elif dims == 3:
-        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
-            dtype
-        )
+        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
 
     out, S = F.nvidia_transform(A, to_order=orderOut)
 
@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
         if dims == 2:
             n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
         elif dims == 3:
-            n = (
-                A.shape[0]
-                * A.shape[1]
-                * (A.shape[2] + (32 - (A.shape[2] % 32)))
-            )
+            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
         assert out.numel() == n
     elif orderOut == "col_turing":
         # 32 col 8 row tiles
-        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
-            A.shape[1] + (32 - (A.shape[1] % 32))
-        )
+        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
         assert out.numel() == n
         total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
         for row in range(A.shape[0]):
@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
                 j = col
 
                 coltile = (col // 32) + (1 if col % 32 != 0 else 0)
-                rowtile = (
-                    (row // 8) + (1 if row % 8 != 0 else 0)
-                ) * total_coltile
+                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                 offset = 32 * 8 * (rowtile + coltile)
                 col2 = col % 32
                 row2 = (row % 8) * 32
@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
                 # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
 
     if orderOut == "col32":
-        out2, S = F.nvidia_transform(
-            out, from_order=orderOut, to_order="row", state=S
-        )
+        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
         torch.testing.assert_close(A, out2)
 
 
@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
 def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
     for i in range(k):
         if dims == 2:
-            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
-                torch.int8
-            )
+            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
         elif dims == 3:
-            A = torch.randint(
-                -128, 127, size=(dim1, dim2, dim3), device="cuda"
-            ).to(torch.int8)
-        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
-            torch.int8
-        )
+            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
+        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
         C1 = torch.matmul(A.float(), B.t().float())
 
         A2, SA = F.transform(A, "col32")
@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
         torch.testing.assert_close(C1, C3.float())
 
         # transpose
-        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
-            torch.int8
-        )
+        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
         C1 = torch.matmul(A.float(), B.float())
 
         B2t, SBt = F.transform(B, "col_turing", transpose=True)
@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
         if dims == 2:
             A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
         elif dims == 3:
-            A = torch.normal(
-                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
-            ).half()
+            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
         B = torch.randn((dim4, dim3), device="cuda").half()
         torch.nn.init.xavier_uniform_(B)
         C1 = torch.matmul(A, B.t())
@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
         # C3, S = F.transform(C2, 'row', state=SC)
         # torch.testing.assert_close(C1, C3.float())
 
+
 @pytest.mark.parametrize(
     ("batch", "seq", "model", "hidden"),
     [
@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(k):
-
         out1 = torch.matmul(A, w1.t())  # fc1
         # out2 = torch.matmul(out1, w2.t())# fc2
 
@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
 def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
     inner = torch.randint(1, 128, size=(1,)).item()
     bias = None
-    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
+    if has_bias:
+        bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
     formatB = F.get_special_format_str()
     for i in range(1):
         A = torch.randn(dim1, inner, device="cuda")
         B = torch.randn(dim4, inner, device="cuda")
         C1 = torch.matmul(A.half(), B.t().half())
-        if has_bias: C1 += bias
+        if has_bias:
+            C1 += bias
 
         A1, maxA = F.vectorwise_quant(A, dim=1)
         B1, maxB = F.vectorwise_quant(B, dim=1)
@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
 
         C3, S = F.nvidia_transform(C2, "row", state=SC)
         C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
-        if has_bias: C4 += bias
+        if has_bias:
+            C4 += bias
 
         # TODO: is something wrong here? If so, the problem goes deeper
         # n = C1.numel()
@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
         else:
             assert False
 
-        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
-            A, threshold=threshold
-        )
+        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
 
         A_blocked = einops.rearrange(
             torch.abs(A),
@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
         torch.testing.assert_close(row_stats1_trunc, row_stats2)
         torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
 
-        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
-            A, threshold=0.0
-        )
+        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
 
         torch.testing.assert_close(col_stats1, col_stats2)
         torch.testing.assert_close(row_stats1, row_stats2)
@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
         torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
 
         n = CAt.numel()
-        num_not_close_rows = (
-            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
-        )
-        num_not_close_cols = (
-            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
-        )
+        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
 
         # allow for 1:500 error due to rounding differences
         min_error = 1 / 500
         if num_not_close_cols > (min_error * n):
-            print(
-                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
-            )
+            print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
             assert False
         if num_not_close_rows > (min_error * n):
-            print(
-                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
-            )
+            print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
             assert False
 
         torch.testing.assert_close(Srow.flatten().float(), statsA)
@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
     ("dim1", "dim4", "inner"),
     (
         pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
-        for (dim1, dim4, inner)
-        in zip(
+        for (dim1, dim4, inner) in zip(
             get_test_dims(1, 4 * 1024, n=4),
             get_test_dims(1, 4 * 1024, n=4),
             get_test_dims(1, 4 * 1024, n=4),
         )
-    )
+    ),
 )
 def test_integrated_igemmlt(dim1, dim4, inner):
     for i in range(k):
@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
     ("dim1", "dim4", "inner"),
     (
         pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
-        for (dim1, dim4, inner)
-        in zip(
+        for (dim1, dim4, inner) in zip(
             get_test_dims(1, 4 * 1024, n=6),
             get_test_dims(1, 4 * 1024, n=6),
             get_test_dims(1, 4 * 1024, n=6),
         )
-    )
+    ),
 )
 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_igemmlt_row_scale(dim1, dim4, inner):
@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
 
         c = 10.0 * inner * scale
         row_scale = torch.ones_like(maxA) / c
-        outC32, SC = F.igemmlt(
-            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
-        )
+        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
         C3, S = F.nvidia_transform(outC32, "row", state=SC)
         maxval = torch.abs(C3).max()
         if maxval == 127:
@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(k):
-        outC32, SC = F.igemmlt(
-            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
-        )
+        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
     torch.cuda.synchronize()
     print("row-wise", time.time() - t0)
 
@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
 def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
     for i in range(k):
         if dims == 2:
-            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
-                dtype
-            )
+            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
         elif dims == 3:
-            A = torch.randint(
-                10, 99, size=(dim1, dim2, dim3), device="cuda"
-            ).to(dtype)
+            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
 
         A.view(-1)[-1] = -1
         if transpose:
@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
 
         idx = torch.abs(A) >= threshold
         CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
-        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
-            A, threshold=threshold
-        )
+        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
 
         if coo_tensor is not None:
             A1 = A * idx
             A2 = torch.zeros_like(A)
-            A2[
-                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
-            ] = coo_tensor.values
+            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
             torch.testing.assert_close(A1, A2)
 
             A1 = A * (idx == 0)
             A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-            torch.testing.assert_close(
-                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
-            )
+            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
 
 
 @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
         nnz = (idx == 1).sum().item()
         rows, cols = torch.where(idx)
         values = A[idx]
-        cooA = F.COOSparseTensor(
-            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-        )
+        cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
         A2 = A * idx
 
         if transposed_B:
@@ -1303,9 +1219,7 @@ def test_spmm_bench():
     print(nnz / idx.numel())
     rows, cols = torch.where(idx)
     values = A[idx]
-    cooA = F.COOSparseTensor(
-        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-    )
+    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
 
     for i in range(10):
         out2 = F.spmm_coo(cooA, B)
@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
         out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
         out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
 
-        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
-            A, threshold=threshold
-        )
+        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
         C32A, SA = F.transform(CA, "col32")
 
         out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
     nnz = (idx == 1).sum().item()
     rows, cols = torch.where(idx)
     values = A[idx]
-    cooA = F.COOSparseTensor(
-        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-    )
+    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
     A2 = A * idx
     out1 = torch.matmul(A2.half(), B.half())
     out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
     std = out1.std()
     out1 /= std
     out2 /= std
-    assert_all_approx_close(
-        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
-    )
+    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
     # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
 
     idx_col = torch.randint(0, A2.shape[-1], size=(15,))
@@ -1443,9 +1351,7 @@ def test_coo2csr():
     nnz = (idx == 1).sum().item()
     rows, cols = torch.where(idx)
     values = A[idx]
-    cooA = F.COOSparseTensor(
-        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-    )
+    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
     A2 = A * idx
     csrA = F.coo2csr(cooA)
     counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
@@ -1463,9 +1369,7 @@ def test_coo2csc():
     nnz = (idx == 1).sum().item()
     rows, cols = torch.where(idx)
     values = A[idx]
-    cooA = F.COOSparseTensor(
-        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-    )
+    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
     A2 = A * idx
     cscA = F.coo2csc(cooA)
     counts = cscA.colptr[1:] - cscA.colptr[:-1]
@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     nnz = (idx == 1).sum().item()
     rows, cols = torch.where(idx)
     values = A[idx]
-    cooA = F.COOSparseTensor(
-        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
-    )
+    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
     A2 = A * idx
     out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
     out1 = torch.matmul(A2, B.half())
@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
 
 @pytest.mark.parametrize(
     ("batch", "seq", "model", "hidden"),
-    [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")],
+    [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")],
 )
 @pytest.mark.benchmark
 def test_bench_matmul(batch, seq, model, hidden):
@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0
 
-    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
-    #linearMixedBit.eval()
+    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
+    # linearMixedBit.eval()
 
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
-    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    print(
+        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
+    )
 
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
-    #torch.cuda.synchronize()
-    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    # torch.cuda.synchronize()
+    # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
-    #torch.cuda.synchronize()
-    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    # torch.cuda.synchronize()
+    # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
-    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
     torch.cuda.synchronize()
-    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    bnb.matmul(A, B)
-    #torch.cuda.synchronize()
-    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    bnb.matmul(A, B, threshold=6.0)
-    #torch.cuda.synchronize()
-    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-
-    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    #C32A, SA = F.transform(CA, "col32")
-    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    #CxB, SB = F.transform(CB, to_order=formatB)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    # C32A, SA = F.transform(CA, "col32")
+    # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    # CxB, SB = F.transform(CB, to_order=formatB)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-    #torch.cuda.synchronize()
-    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-
-    #BA, statsB = F.vectorwise_quant(B, dim=1)
-    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    # BA, statsB = F.vectorwise_quant(B, dim=1)
+    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    A2 = A.view(-1, A.shape[-1]).contiguous()
     #    CA, statsA = F.vectorwise_quant(A2, dim=1)
     #    C32A, SA = F.nvidia_transform(CA, "col32")
     #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
     #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
     #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
-    #torch.cuda.synchronize()
-    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-
-    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    A2 = A.view(-1, A.shape[-1]).contiguous()
     #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
     #    C32A, SA = F.nvidia_transform(CA, "col32")
     #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
     #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
     #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    #torch.cuda.synchronize()
-    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    #linear8bit(A)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # linear8bit(A)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    linear8bit(A)
-    #torch.cuda.synchronize()
-    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    #linearMixedBit(A)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # linearMixedBit(A)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    linearMixedBit(A)
-    #torch.cuda.synchronize()
-    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    #linear8bit_train(A)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # linear8bit_train(A)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    linear8bit_train(A)
-    #torch.cuda.synchronize()
-    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    #linear8bit_train_thresh(A)
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # linear8bit_train_thresh(A)
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    linear8bit_train(A)
-    #torch.cuda.synchronize()
-    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    # torch.cuda.synchronize()
+    # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
 
 def test_zeropoint():
     def quant_zp(x):
@@ -1778,8 +1682,8 @@ def quant_zp(x):
     C2 -= A.sum(1).view(-1, 1) * zp
 
     ca, cqa, cza = quant_zp(A)
-    #print(ca.min(), ca.max())
-    #print((ca - cza).min(), (ca - cza).max())
+    # print(ca.min(), ca.max())
+    # print((ca - cza).min(), (ca - cza).max())
 
     zp = 1
     scale = 2.0
@@ -1808,14 +1712,14 @@ def quant_zp(x):
     C7 -= zpa * zpb * A.shape[1]
     C7 /= qa * qb
 
-    #print("")
+    # print("")
     # print(C0.flatten()[:10])
-    #print(C1.flatten()[:10])
-    #print(C2.flatten()[:10])
-    #print(C3.flatten()[:10])
-    #print(C5.flatten()[:10])
-    #print(C6.flatten()[:10])
-    #print(C7.flatten()[:10])
+    # print(C1.flatten()[:10])
+    # print(C2.flatten()[:10])
+    # print(C3.flatten()[:10])
+    # print(C5.flatten()[:10])
+    # print(C6.flatten()[:10])
+    # print(C7.flatten()[:10])
     err1 = torch.abs(C1 - C2).mean().item()
     err2 = torch.abs(C1 - C3).mean().item()
     err3 = torch.abs(C1 - C4).mean().item()
@@ -1852,16 +1756,15 @@ def test_extract_outliers():
         torch.testing.assert_close(outliers1, outliers2)
 
 
-
 def test_blockwise_cpu_large():
     diffs = []
     reldiffs = []
     batch = 128
     seq = 128
-    for hidden in [128]:#, 14336]:
+    for hidden in [128]:  # , 14336]:
         for blocksize in [4096, 16384]:
             for i in range(2):
-                A1 = torch.randn(batch, seq, hidden, device='cpu')
+                A1 = torch.randn(batch, seq, hidden, device="cpu")
                 t0 = time.time()
                 C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                 A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
             # print(sum(reldiffs)/len(reldiffs))
 
 
-
 def test_fp8_quant():
     for e_bits in range(1, 7):
-        p_bits = 7-e_bits
+        p_bits = 7 - e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()
 
         abserr = []
@@ -1888,12 +1790,12 @@ def test_fp8_quant():
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
             diff = torch.abs(A1 - A2)
-            reldiff = diff/torch.abs(A1+1e-8)
+            reldiff = diff / torch.abs(A1 + 1e-8)
             abserr.append(diff.mean().item())
             relerr.append(reldiff.mean().item())
-            #assert diff < 0.0075
-        #print(sum(abserr)/len(abserr))
-        #print(sum(relerr)/len(relerr))
+            # assert diff < 0.0075
+        # print(sum(abserr)/len(abserr))
+        # print(sum(relerr)/len(relerr))
 
         abserr = []
         relerr = []
@@ -1902,12 +1804,12 @@ def test_fp8_quant():
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)
             diff = torch.abs(A1 - A2)
-            reldiff = diff/torch.abs(A1+1e-8)
+            reldiff = diff / torch.abs(A1 + 1e-8)
             abserr.append(diff.mean().item())
             relerr.append(reldiff.mean().item())
-            #assert diff < 0.0075
-        #print(sum(abserr)/len(abserr))
-        #print(sum(relerr)/len(relerr))
+            # assert diff < 0.0075
+        # print(sum(abserr)/len(abserr))
+        # print(sum(relerr)/len(relerr))
 
         abserr = []
         relerr = []
@@ -1916,50 +1818,48 @@ def test_fp8_quant():
             C, SC = F.quantize_blockwise(A1)
             A2 = F.dequantize_blockwise(C, SC)
             diff = torch.abs(A1 - A2)
-            reldiff = diff/torch.abs(A1+1e-8)
+            reldiff = diff / torch.abs(A1 + 1e-8)
             abserr.append(diff.mean().item())
             relerr.append(reldiff.mean().item())
-            #assert diff < 0.0075
-        #print(3, sum(abserr)/len(abserr))
-        #print(3, sum(relerr)/len(relerr))
+            # assert diff < 0.0075
+        # print(3, sum(abserr)/len(abserr))
+        # print(3, sum(relerr)/len(relerr))
 
 
 def test_few_bit_quant():
-
-    #print('')
+    # print('')
     for bits in range(2, 9):
-        #print('='*30, bits, '='*30)
-        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
+        # print('='*30, bits, '='*30)
+        for method in ["linear", "fp8", "dynamic", "quantile"]:
             abserrs = []
             relerrs = []
             code = None
-            if method == 'linear':
+            if method == "linear":
                 code = F.create_linear_map(True, total_bits=bits).cuda()
-            elif method == 'fp8':
-                ebits = math.ceil(bits/2)
-                pbits = bits-ebits-1
+            elif method == "fp8":
+                ebits = math.ceil(bits / 2)
+                pbits = bits - ebits - 1
                 code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
-            elif method == 'dynamic':
-                code = F.create_dynamic_map(True, bits-0, bits).cuda()
-            elif method == 'quantile':
-                values = torch.randn(2048, 2048, device='cuda')
+            elif method == "dynamic":
+                code = F.create_dynamic_map(True, bits - 0, bits).cuda()
+            elif method == "quantile":
+                values = torch.randn(2048, 2048, device="cuda")
                 code = F.create_quantile_map(values, bits).cuda()
             # for some data types we have no zero
             # for some data types we have one zero
             # for some data types we have two zeros
-            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
-            #print(method, (code==0).sum())
+            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
+            # print(method, (code==0).sum())
             assert code.numel() == 256
             for i in range(10):
-
-                values = torch.randn(1, 32, device='cuda')
+                values = torch.randn(1, 32, device="cuda")
                 values /= values.abs().max()
-                #values[values.abs() < 1e-6] += 1e-5
+                # values[values.abs() < 1e-6] += 1e-5
 
                 q1 = []
                 v1 = []
                 for v in values[0]:
-                    idx = torch.abs(v-code).argmin()
+                    idx = torch.abs(v - code).argmin()
                     q1.append(idx.item())
                     v1.append(code[idx].item())
 
@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
                 v2 = F.dequantize_blockwise(q2, S2)
 
                 idx = torch.isclose(q1.int(), q2.int())
-                err2 = torch.abs(v2-values)
+                err2 = torch.abs(v2 - values)
                 abserrs.append(err2.mean().item())
-                relerrs.append((err2/(1e-10+values).abs()).mean().item())
+                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                 if idx.sum():
                     # some weird cases
-                    err1 = torch.abs(v1-values).mean()
-                    #assert err2.mean() <= err1
+                    err1 = torch.abs(v1 - values).mean()
+                    # assert err2.mean() <= err1
 
                 else:
                     torch.testing.assert_close(q1, q2)
-            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
-    #assert False
+            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+    # assert False
 
 
 def test_kbit_quantile_estimation():
     for i in range(100):
-        data = torch.randn(1024, 1024, device='cuda')
+        data = torch.randn(1024, 1024, device="cuda")
         for bits in range(2, 9):
-            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
+            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
             val1 = torch.Tensor(norm.ppf(p)).cuda()
             val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
-            err = torch.abs(val1-val2).mean()
+            err = torch.abs(val1 - val2).mean()
             assert err < 0.038
 
     for i in range(100):
-        data = torch.randn(1024, 1024, device='cuda')
+        data = torch.randn(1024, 1024, device="cuda")
         for bits in range(2, 4):
-            total_values = 2**bits-1
-            p = np.linspace(0, 1, 2*total_values+1)
-            idx = np.arange(1, 2*total_values+1, 2)
+            total_values = 2**bits - 1
+            p = np.linspace(0, 1, 2 * total_values + 1)
+            idx = np.arange(1, 2 * total_values + 1, 2)
             p = p[idx]
-            offset = 1/(2*total_values)
-            p = np.linspace(offset, 1-offset, total_values)
+            offset = 1 / (2 * total_values)
+            p = np.linspace(offset, 1 - offset, total_values)
             val1 = torch.Tensor(norm.ppf(p)).cuda()
-            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
-            err = torch.abs(val1-val2).mean()
+            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
+            err = torch.abs(val1 - val2).mean()
             assert err < 0.035
 
 
 @pytest.mark.benchmark
 def test_bench_dequantization():
-    a = torch.rand(1024, 1024, device='cuda').half()
-    code =F.create_fp8_map(True, 3, 0, 4).cuda()
+    a = torch.rand(1024, 1024, device="cuda").half()
+    code = F.create_fp8_map(True, 3, 0, 4).cuda()
     qa, SA = F.quantize_blockwise(a, code=code)
     print(qa.max())
 
-    max_theoretical_mu =  1024*1024*2/1024**3/672*1000*1000
-    #print(max_theoretical_mu)
+    max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
+    # print(max_theoretical_mu)
 
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
         qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
-    #print((time.time()-t0)/1e6)
-
+    # print((time.time()-t0)/1e6)
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
         result = 0
         bias = 3
         sign, e1, e2, p1 = bits
-        idx = sign*8 + e1*4 + e2*2 + p1*1
+        idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
         sign = -1.0 if sign else 1.0
-        exp = e1*2 + e2*1
+        exp = e1 * 2 + e2 * 1
         if exp == 0:
             # sub-normal
-            if p1 == 0: result = 0
-            else: result = sign*0.0625
+            if p1 == 0:
+                result = 0
+            else:
+                result = sign * 0.0625
         else:
             # normal
-            exp = 2**(-exp + bias + 1)
+            exp = 2 ** (-exp + bias + 1)
             frac = 1.5 if p1 else 1.0
-            result = sign*exp*frac
+            result = sign * exp * frac
         code[idx] = result
 
-    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
+    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
     qa, SA = F.quantize_fp4(A1, blocksize=64)
     A2 = F.dequantize_fp4(qa, SA)
 
     err = (A1 - A2).abs().float()
-    relerr = (err/(A1.abs().float()+1e-8)).mean()
+    relerr = (err / (A1.abs().float() + 1e-8)).mean()
     idx = err > 1.0
     err = err.mean()
 
@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
     assert relerr.item() < 0.28
 
 
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
 def test_4bit_compressed_stats(quant_type):
     for blocksize in [128, 64]:
         errs1 = []
         errs2 = []
         for i in range(10):
-            A1 = torch.randn(1024, 1024, device='cuda').half()
+            A1 = torch.randn(1024, 1024, device="cuda").half()
             q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
-            q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
+            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
             A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
             A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
 
-
             err = (A1 - A2).abs().float()
-            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            relerr = (err / (A1.abs().float() + 1e-15)).mean()
             err = err.mean()
 
             errs1.append(err.item())
 
-
             assert err.item() < 0.11
             assert relerr.item() < 0.28
 
             err = (A1 - A3).abs().float()
-            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            relerr = (err / (A1.abs().float() + 1e-15)).mean()
             err = err.mean()
 
             errs2.append(err.item())
@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
             assert err.item() < 0.11
             assert relerr.item() < 0.28
 
-        #print(sum(errs1)/len(errs1), blocksize, quant_type)
-        #print(sum(errs2)/len(errs2), blocksize, quant_type)
-
+        # print(sum(errs1)/len(errs1), blocksize, quant_type)
+        # print(sum(errs2)/len(errs2), blocksize, quant_type)
 
 
-
-#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
-@pytest.mark.parametrize("quant_type", ['nf4'])
+# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+@pytest.mark.parametrize("quant_type", ["nf4"])
 @pytest.mark.benchmark
 def test_bench_4bit_dequant(quant_type):
     blocksize = 256
-    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
+    a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
     qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
 
-    input_size = a.numel()/2
-    output_size = a.numel()*2
-    num_bytes = input_size+output_size
-    GB = num_bytes/1e9
-    max_theoretical_s =  GB/768
-    #print(max_theoretical_s*1e6)
-    b = torch.randn(128, 1024*12, device='cuda').half()
+    input_size = a.numel() / 2
+    output_size = a.numel() * 2
+    num_bytes = input_size + output_size
+    GB = num_bytes / 1e9
+    max_theoretical_s = GB / 768
+    # print(max_theoretical_s*1e6)
+    b = torch.randn(128, 1024 * 12, device="cuda").half()
 
     iters = 100
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
-        #b.copy_(a)
+        # b.copy_(a)
     torch.cuda.synchronize()
-    #print((time.time()-t0)/iters*1e6)
+    # print((time.time()-t0)/iters*1e6)
 
-    #torch.cuda.synchronize()
-    #t0 = time.time()
-    #for i in range(iters):
+    # torch.cuda.synchronize()
+    # t0 = time.time()
+    # for i in range(iters):
     #    torch.matmul(b, a.t())
-    #torch.cuda.synchronize()
-    #print((time.time()-t0)/iters*1e6)
-
+    # torch.cuda.synchronize()
+    # print((time.time()-t0)/iters*1e6)
 
 
 def test_normal_map_tree():
     code = F.create_normal_map()
-    values =code[:8].tolist() + code[-8:].tolist()
+    values = code[:8].tolist() + code[-8:].tolist()
     num_pivots = 1
-    #print(values)
-    while num_pivots <16:
-        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
-        #print(idx)
+    # print(values)
+    while num_pivots < 16:
+        idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
+        # print(idx)
         num_pivots *= 2
         pivots = []
         for i in idx:
-            pivots.append((values[i-1]+values[i])/2)
-        #print(pivots)
+            pivots.append((values[i - 1] + values[i]) / 2)
+        # print(pivots)
 
 
 @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
-@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'])
-@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'])
+@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize(
+    "quant_storage",
+    [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
+    ids=describe_dtype,
+)
 def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
     for dim in [128, 256, 512, 1024]:
-    #for dim in [4*1024]:
-    #for dim in [1*16]:
+        # for dim in [4*1024]:
+        # for dim in [1*16]:
         errs1 = []
         errs2 = []
         errs3 = []
@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
         max_errs2 = []
         max_errs3 = []
 
-
         for i in range(100):
-            if kind == 'fc1':
-                A = torch.randn(1, dim, dtype=dtype, device='cuda')
-                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            elif kind == 'fc2':
-                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
-                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            elif kind == 'attn':
-                A = torch.randn(1, dim, dtype=dtype, device='cuda')
-                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            elif kind == 'attn_packed':
-                A = torch.randn(1, dim, dtype=dtype, device='cuda')
-                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
+            if kind == "fc1":
+                A = torch.randn(1, dim, dtype=dtype, device="cuda")
+                B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+            elif kind == "fc2":
+                A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
+                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+            elif kind == "attn":
+                A = torch.randn(1, dim, dtype=dtype, device="cuda")
+                B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+            elif kind == "attn_packed":
+                A = torch.randn(1, dim, dtype=dtype, device="cuda")
+                B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+
+            qB, state = F.quantize_4bit(
+                B,
+                quant_type=storage_type,
+                compress_statistics=double_quant,
+                quant_storage=quant_storage,
+            )
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
 
-            err1 = (C1-C2).abs().float()
-            err2 = (C3-C2).abs().float()
-            err3 = (C3-C1).abs().float()
+            err1 = (C1 - C2).abs().float()
+            err2 = (C3 - C2).abs().float()
+            err3 = (C3 - C1).abs().float()
 
-            mag1 = torch.abs(C1).float()+1e-5
-            mag2 = torch.abs(C3).float()+1e-5
-            mag3 = torch.abs(C3).float()+1e-5
+            mag1 = torch.abs(C1).float() + 1e-5
+            mag2 = torch.abs(C3).float() + 1e-5
+            mag3 = torch.abs(C3).float() + 1e-5
 
-            relerr1 = err1/mag1
-            relerr2 = err2/mag2
-            relerr3 = err3/mag3
+            relerr1 = err1 / mag1
+            relerr2 = err2 / mag2
+            relerr3 = err3 / mag3
 
             max_err1 = err1.max()
             max_err2 = err2.max()
@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
             max_errs2.append(max_err2.item())
             max_errs3.append(max_err3.item())
 
-            c = int(C1.numel()*0.0014*(dim/256))+1
+            c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
 
             c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
-        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
-        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
-        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
-        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
-        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
-        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
-        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
-        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
-        absratio = err2/err3
-        relratio = relerr2/relerr3
-        maxratio = relerr2/relerr3
+        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
+        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
+        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
+        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
+        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
+        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
+        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
+        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
+        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
+        absratio = err2 / err3
+        relratio = relerr2 / relerr3
+        maxratio = relerr2 / relerr3
 
         # for debugging if the tests fail
         #
-        #print('='*80)
-        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
-        #print(C1.flatten()[-20:])
-        #print(C2.flatten()[-20:])
-        #print(f'inference vs training abs: {err1}')
-        #print(f'inference vs training rel: {relerr1}')
-        #print(f'inference vs training max: {maxerr1}')
-        #print(f'inference vs training vs torch err ratio abs: {absratio}')
-        #print(f'inference vs training vs torch err ratio rel: {relratio}')
-        #print(f'inference vs training vs torch err ratio max: {maxratio}')
+        # print('='*80)
+        # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+        # print(C1.flatten()[-20:])
+        # print(C2.flatten()[-20:])
+        # print(f'inference vs training abs: {err1}')
+        # print(f'inference vs training rel: {relerr1}')
+        # print(f'inference vs training max: {maxerr1}')
+        # print(f'inference vs training vs torch err ratio abs: {absratio}')
+        # print(f'inference vs training vs torch err ratio rel: {relratio}')
+        # print(f'inference vs training vs torch err ratio max: {maxratio}')
         if dtype == torch.float16:
             if dim <= 512:
                 assert err1 < 7e-5
@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
             assert relratio < 1.04 and relratio > 0.96
             assert maxratio < 1.02 and maxratio > 0.98
 
+
 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
-    n = 32*10
+    n = 32 * 10
     A = F.get_paged(n, n, dtype=torch.float32)
     B = F.get_paged(n, n, dtype=torch.uint8)
     B2 = F.get_paged(n, n, dtype=torch.float32)
     assert A.is_paged
     assert B.is_paged
-    assert A.page_deviceid==0
-    assert B.page_deviceid==0
+    assert A.page_deviceid == 0
+    assert B.page_deviceid == 0
     F.fill(A, 17.0)
     F.fill(B, 17)
     F.fill(B2, 2)
-    assert (A==17).sum().item() == n*n
-    assert (B==17).sum().item() == n*n
-    C = A*B.float()
-    assert (C==289).sum().item() == n*n
+    assert (A == 17).sum().item() == n * n
+    assert (B == 17).sum().item() == n * n
+    C = A * B.float()
+    assert (C == 289).sum().item() == n * n
     F._mul(A, B2)
     F._mul(A, B2)
     F._mul(A, B2)
-    assert (A==17*(2**3)).sum().item() == n*n
-   # F.prefetch_tensor(A)
-   # F.prefetch_tensor(B)
+    assert (A == 17 * (2**3)).sum().item() == n * n
+
+
+# F.prefetch_tensor(A)
+# F.prefetch_tensor(B)
 
 
-   # F.fill(B2, 17.0)
-   # F._mul(A, B2)
+# F.fill(B2, 17.0)
+# F._mul(A, B2)
 
-   # F.prefetch_tensor(A, to_cpu=True)
-   # F.prefetch_tensor(B, to_cpu=True)
-   # F.prefetch_tensor(B2, to_cpu=True)
-   # torch.cuda.synchronize()
+# F.prefetch_tensor(A, to_cpu=True)
+# F.prefetch_tensor(B, to_cpu=True)
+# F.prefetch_tensor(B2, to_cpu=True)
+# torch.cuda.synchronize()
 
-   # assert (A==17).sum().item() == n*n
+# assert (A==17).sum().item() == n*n
 
-   # torch.testing.assert_close(A, torch.ones(A.shape)*289)
+# torch.testing.assert_close(A, torch.ones(A.shape)*289)
 
 
-@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
+@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
 def test_gemv_eye_4bit(storage_type, dtype, double_quant):
     dims = 10
     torch.random.manual_seed(np.random.randint(0, 412424242))
     dims = get_test_dims(0, 8192, n=dims)
-    dims = [dim + (64-(dim % 64)) for dim in dims]
-    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
+    dims = [dim + (64 - (dim % 64)) for dim in dims]
+    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
     for dim in dims:
-        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
-        B = torch.eye(dim, dtype=dtype, device='cuda')
+        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
+        B = torch.eye(dim, dtype=dtype, device="cuda")
 
         qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
         C3 = torch.matmul(A, B.t())
@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
         torch.testing.assert_close(A, C3)
         torch.testing.assert_close(A, C1)
         torch.testing.assert_close(A, C2)
-        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
-        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
+        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
+        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
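For reference, the test_functional.py hunks above exercise the 4-bit quantize/dequantize/matmul helpers. Below is a minimal sketch of that round trip, using only the calls that appear in these tests; the shapes, the nf4 choice, and the CUDA device are illustrative assumptions rather than the tests' parametrization.

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# illustrative activation/weight shapes, mirroring the (1, dim) x (4*dim, dim) case in the tests
A = torch.randn(1, 1024, dtype=torch.float16, device="cuda")
B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

# quantize the weight to 4-bit NF4; `state` carries the blockwise absmax statistics
qB, state = F.quantize_4bit(B, quant_type="nf4", compress_statistics=True)

# round-trip back to fp16 and inspect the reconstruction error ...
B_deq = F.dequantize_4bit(qB, state, quant_type="nf4")
print((B - B_deq).abs().mean())

# ... or multiply directly against the packed 4-bit weight
C = bnb.matmul_4bit(A, qB.t(), state)

The gemv path used in test_gemv_4bit (F.gemv_4bit) takes the same packed weight and quantization state.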
diff --git a/tests/test_generation.py b/tests/test_generation.py
index ef354d70a..911aa14da 100644
--- a/tests/test_generation.py
+++ b/tests/test_generation.py
@@ -10,56 +10,61 @@
 
 
 def get_4bit_config():
-  return transformers.BitsAndBytesConfig(
-    load_in_4bit=True,
-    load_in_8bit=False,
-    llm_int8_threshold=6.0,
-    llm_int8_has_fp16_weight=False,
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type='nf4',
-  )
+    return transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        load_in_8bit=False,
+        llm_int8_threshold=6.0,
+        llm_int8_has_fp16_weight=False,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+    )
 
 
 def get_model_and_tokenizer(config):
     model_name_or_path, quant_type = config
     bnb_config = get_4bit_config()
-    if quant_type == '16bit':
+    if quant_type == "16bit":
         bnb_config.load_in_4bit = False
     else:
-        bnb_config.bnb_4bit_quant_type= quant_type
-    model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path,
+        bnb_config.bnb_4bit_quant_type = quant_type
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
         quantization_config=bnb_config,
-        max_memory={0:'48GB'},
-        device_map='auto',
-        torch_dtype=torch.bfloat16
-        ).eval()
+        max_memory={0: "48GB"},
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    ).eval()
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
 
     return model, tokenizer
 
+
 def get_prompt_for_generation_eval(text, add_roles=True):
     description = (
         "A chat between a curious human and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the user's questions."
     )
     if add_roles:
-        prompt = f'{description} ### Human: {text} ### Assistant:'
+        prompt = f"{description} ### Human: {text} ### Assistant:"
     else:
-        prompt = f'{description} {text}'
+        prompt = f"{description} {text}"
     return prompt
 
+
 def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval):
     text = prompt_func(text)
-    inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
+    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
+    outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
-dtypes = ['nf4', 'fp4']
 
-@pytest.fixture(scope='session', params=product(models, dtypes))
+models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
+dtypes = ["nf4", "fp4"]
+
+
+@pytest.fixture(scope="session", params=product(models, dtypes))
 def model_and_tokenizer(request):
     model, tokenizer = get_model_and_tokenizer(request.param)
     yield request.param, model, tokenizer
@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
     )
     generation_config.max_new_tokens = 20
 
-
-    #text = 'Please write down the first 50 digits of pi.'
-    #text = get_prompt_for_generation_eval(text)
-    #text += ' Sure, here the first 50 digits of pi: 3.14159'
+    # text = 'Please write down the first 50 digits of pi.'
+    # text = get_prompt_for_generation_eval(text)
+    # text += ' Sure, here the first 50 digits of pi: 3.14159'
     n_cases = 6
-    text = '3.14159'
-    if hasattr(model.config, 'quantization_config'):
+    text = "3.14159"
+    if hasattr(model.config, "quantization_config"):
         model.config.quantization_config.bnb_4bit_compute_dtype = dtype
         model.config.quantization_config.bnb_4bit_use_double_quant = DQ
 
     if not inference_kernel:
-        text = [text]*n_cases
-    inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    x = inputs['input_ids']
+        text = [text] * n_cases
+    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
+    x = inputs["input_ids"]
     outputs = []
     if inference_kernel:
         for i in range(n_cases):
@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
         outputs = model.generate(x, generation_config=generation_config)
         outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
-
     assert len(outputs) == n_cases
     failure_count = 0
     for i in range(n_cases):
-        if not outputs[i][:len(str(math.pi))] == str(math.pi):
+        if not outputs[i][: len(str(math.pi))] == str(math.pi):
             failure_count += 1
-    failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4)
+    failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4
     if failure_count > failure_max:
         print(math.pi)
         for out in outputs:
             print(out)
-        raise ValueError(f'Failure count: {failure_count}/{n_cases}')
+        raise ValueError(f"Failure count: {failure_count}/{n_cases}")
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 567e1a466..bbbd05335 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(
-        *layer_shape, dtype=original_dtype, device="cpu"
-    )  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(
-        data=linear.weight, quant_type=quant_type, requires_grad=False
-    )
+    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         target_compression = (
             0.143 if original_dtype == torch.float32 else 0.29
         )  # these numbers get lower as weight shape increases
-        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
+        ratio_error_msg = (
+            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
+        )
         assert size_ratio < target_compression, ratio_error_msg
 
 
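The test_linear4bit.py hunks reflow the serialization test. The pattern it relies on, wrapping an existing full-precision weight in Params4bit so that quantization happens when the layer moves to CUDA, looks roughly like the sketch below; the layer sizes and the nf4/compute-dtype choices are illustrative, not the test's parametrization.

import torch
import bitsandbytes as bnb

# an ordinary CPU layer whose weights we want to reuse
linear = torch.nn.Linear(300, 400, dtype=torch.float32, device="cpu")

# build the 4-bit layer on the meta device, then hand it the existing weight
linear_q = bnb.nn.Linear4bit(300, 400, compute_dtype=torch.float16, quant_type="nf4", device="meta")
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, quant_type="nf4", requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)

# moving to CUDA is what triggers the actual 4-bit quantization of the weight
linear_q = linear_q.to("cuda")

x = torch.randn(2, 300, dtype=torch.float16, device="cuda")
y = linear_q(x)  # forward pass against the quantized weight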
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index edc3409cd..4b62abd6d 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -19,6 +19,7 @@
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
 
+
 @pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
     reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
     linear_custom.state.force_no_igemmlt = True
 
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+        linear.weight.data.clone(),
+        requires_grad=False,
+        has_fp16_weights=False,
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()
@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
 @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
+def test_linear_serialization(
+    has_fp16_weights,
+    serialize_before_forward,
+    deserialize_before_cuda,
+    force_no_igemmlt,
+    save_before_forward,
+    load_before_cuda,
+):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)
 
@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         linear_custom.state.force_no_igemmlt = True
 
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+        linear.weight.data.clone(),
+        requires_grad=has_fp16_weights,
+        has_fp16_weights=has_fp16_weights,
     )
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 674620e29..db4d72410 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
     def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super().__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold
+            dim1,
+            dim2,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold
+            dim2,
+            dim1,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
         )
 
     def forward(self, x):
@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
 class LinearFunction(torch.autograd.Function):
     @staticmethod
     def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
-        round_func = (
-            LinearFunction.round_stoachastic if stochastic else torch.round
-        )
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
         norm = math.sqrt(math.pi) / math.sqrt(2.0)
         # std = torch.abs(x).mean()*norm
         std = torch.std(x)
@@ -122,9 +126,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype):
         return x.to(dtype)
 
     def get_8bit_linear(x, stochastic=False):
-        round_func = (
-            LinearFunction.round_stoachastic if stochastic else torch.round
-        )
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
         max1 = torch.abs(x).max()
         x = x / max1 * 127
         x = round_func(x) / 127 * max1
@@ -133,9 +135,7 @@ def get_8bit_linear(x, stochastic=False):
 
     @staticmethod
     def get_8bit_vector_wise(x, dim, stochastic=False):
-        round_func = (
-            LinearFunction.round_stoachastic if stochastic else torch.round
-        )
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
         max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
         max1[max1 == 0] = 1.0
         x = (x * 127) / max1
@@ -219,9 +219,7 @@ def forward(ctx, x, weight, bias=None, args=None):
             weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
             outputq = bnb.functional.igemm(x8, weight8.t())
-            output = LinearFunction.dequant(
-                outputq, S1, S2, x.dtype, args.quant_type
-            )
+            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
             # if torch.rand(1) < 0.01:
             # output32 = torch.matmul(x, weight.t())
             # err = torch.abs(output-output32).float()
@@ -250,37 +248,25 @@ def backward(ctx, grad_output):
         # weight and x are already 8bit
         # -> transform grad_output to 8-bit
         if args.use_8bit_training == "forward+wgrad":
-            grad_output8, S1 = LinearFunction.quant(
-                grad_output, args.quant_type, dim=[0, 1]
-            )
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
             grad_weight8 = bnb.functional.igemm(grad_output8, x8)
-            grad_weight = LinearFunction.dequant(
-                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
-            )
+            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
 
             # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
 
             grad_input = grad_output.matmul(weight)
         elif args.use_8bit_training == "full":
-            grad_output8, S1 = LinearFunction.quant(
-                grad_output, args.quant_type, dim=[0, 1]
-            )
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
             grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
             bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
-            grad_weight = LinearFunction.dequant(
-                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
-            )
+            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
 
-            grad_output8, S1 = LinearFunction.quant(
-                grad_output, args.quant_type, dim=2
-            )
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
             weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
             grad_input8 = bnb.functional.igemm(grad_output8, weight8)
-            grad_input = LinearFunction.dequant(
-                grad_input8, S1, S3, grad_output.dtype, args.quant_type
-            )
+            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
 
         else:
             grad_input = grad_output.matmul(weight)
@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
             opt1.zero_grad(True)
             opt2.step()
             opt2.zero_grad(True)
-            assert_all_approx_close(
-                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
-            )
-            assert_all_approx_close(
-                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
-            )
+            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
+    l1 = (
+        bnb.nn.Linear8bitLt(
+            32,
+            64,
+            threshold=threshold,
+            has_fp16_weights=False,
+            memory_efficient_backward=memory_efficient_backward,
+        )
+        .cuda()
+        .half()
+    )
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .cuda()
-        .half()
-    )
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .half()
-        .cuda()
-    )
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
-    mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
+    mlp = (
+        MLP8bit(
+            32,
+            64,
+            threshold=threshold,
+            has_fp16_weights=False,
+            memory_efficient_backward=memory_efficient_backward,
+        )
+        .half()
+        .to("cuda")
+    )
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc2.weight.device.type == "cuda"
 
     mlp = MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
+        32,
+        64,
+        threshold=threshold,
+        has_fp16_weights=False,
+        memory_efficient_backward=memory_efficient_backward,
+    )
     w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
 
@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
         bnb.nn.LinearFP4,
     ],
-    ids=['Int8Lt', 'FP4'],
+    ids=["Int8Lt", "FP4"],
 )
 def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
     kbit[1].bias.detach().copy_(ref[1].bias)
     ref = ref.half().cuda()
     kbit = kbit.half().cuda()
-    kbit = kbit.half().to('cuda')
+    kbit = kbit.half().to("cuda")
 
     errs1 = []
     errs2 = []
@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
         bgrad1 = ref[0].bias.grad
         bgrad2 = kbit[0].bias.grad
 
-        err1 = (out1-out2).abs().float()
-        err2 = (grad1-grad2).abs().float()
-        relerr1 = (err1/(out1.abs().float()+1e-9))
-        relerr2 = (err2/(grad1.abs().float()+1e-9))
+        err1 = (out1 - out2).abs().float()
+        err2 = (grad1 - grad2).abs().float()
+        relerr1 = err1 / (out1.abs().float() + 1e-9)
+        relerr2 = err2 / (grad1.abs().float() + 1e-9)
         errs1.append(err1.mean().item())
         errs2.append(err2.mean().item())
         relerrs1.append(relerr1.mean().item())
@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
 
         assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
         assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
-    #print('out', sum(errs1)/len(errs1))
-    #print('grad', sum(errs2)/len(errs2))
-    #print('rel out', sum(relerrs1)/len(relerrs1))
-    #print('rel grad', sum(relerrs2)/len(relerrs2))
+    # print('out', sum(errs1)/len(errs1))
+    # print('grad', sum(errs2)/len(errs2))
+    # print('rel out', sum(relerrs1)/len(relerrs1))
+    # print('rel grad', sum(relerrs2)/len(relerrs2))
 
-def test_fp8linear():
 
+def test_fp8linear():
     b = 10
     h = 1024
     inp = torch.randn(b, h).cuda()
-    fp32 = torch.nn.Linear(h, h*2).cuda()
-    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
-    fp32b = torch.nn.Linear(h*2, h).cuda()
-    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
+    fp32 = torch.nn.Linear(h, h * 2).cuda()
+    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
+    fp32b = torch.nn.Linear(h * 2, h).cuda()
+    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
 
     fp8.weight.data.copy_(fp32.weight.data)
     fp8.bias.data.copy_(fp32.bias.data)
@@ -605,34 +603,34 @@ def test_fp8linear():
     a = fp32b(torch.nn.functional.gelu(fp32(inp)))
     b = fp8b(torch.nn.functional.gelu(fp8(inp)))
 
-    err = (a-b).abs().mean()
+    err = (a - b).abs().mean()
 
     a.mean().backward()
     b.mean().backward()
 
-    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
-    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
 
     assert err < 0.05
     assert graderr < 0.00002
     assert bgraderr < 0.00002
 
+
 def test_4bit_warnings():
     dim1 = 64
 
-    with pytest.warns(UserWarning, match=r'inference or training'):
+    with pytest.warns(UserWarning, match=r"inference or training"):
         net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
         net = net.cuda()
         inp = torch.rand(10, dim1).cuda().half()
         net(inp)
-    with pytest.warns(UserWarning, match=r'inference.'):
+    with pytest.warns(UserWarning, match=r"inference."):
         net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
         net = net.cuda()
         inp = torch.rand(1, dim1).cuda().half()
         net(inp)
 
     with pytest.warns(UserWarning) as record:
-
         net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
         net = net.cuda()
         inp = torch.rand(10, dim1).cuda().half()
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 9395b8820..d8c46e415 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -16,6 +16,7 @@
 
 k = 20
 
+
 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
     idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
@@ -33,6 +34,7 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)
 
+
 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
@@ -66,8 +68,14 @@ def rm_path(path):
 )
 
 str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
-str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
+str2optimizers["paged_adamw8bit_blockwise"] = (
+    torch.optim.AdamW,
+    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
+)
+str2optimizers["paged_adam8bit_blockwise"] = (
+    torch.optim.Adam,
+    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
+)
 str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
 str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
 str2optimizers["momentum8bit_blockwise"] = (
@@ -90,9 +98,18 @@ def rm_path(path):
 str2statenames["rmsprop"] = [("square_avg", "state1")]
 str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
 str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
-str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
-str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
-str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
+str2statenames["adam8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["paged_adam8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["paged_adamw8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1"),
+    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
 str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
 str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
@@ -101,7 +118,7 @@ def rm_path(path):
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
 str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
 
-optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
+optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]
 
 
 @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
@@ -109,7 +126,7 @@ def rm_path(path):
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']:
+    if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
         pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
             for name1, name2 in str2statenames[optim_name]:
                 # since Lion can have pretty noisy updates where things lie at the boundary
                 # allow up to 10 errors for Lion
-                assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
-                                         atol=atol, rtol=rtol,
-                                         max_error_count=10)
+                assert_most_approx_close(
+                    torch_optimizer.state[p1][name1],
+                    bnb_optimizer.state[p2][name2],
+                    atol=atol,
+                    rtol=rtol,
+                    max_error_count=10,
+                )
 
         if gtype != torch.float32:
             # the adam buffers should also be close because they are 32-bit
@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8
 
     bnb.optim.GlobalOptimManager.get_instance().initialize()
-    bnb.optim.GlobalOptimManager.get_instance().override_config(
-        p3, "optim_bits", 8
-    )
+    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
 
-    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
-        [p1, p2, p3]
-    )
+    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
     p1 = p1.cuda()
     p2 = p2.cuda()
     p3 = p3.cuda()
@@ -242,7 +259,8 @@ def test_global_config(dim1, dim2, gtype):
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
+    if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
+        pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                     absmax=bnb_optimizer.state[p2][max_val],
                     A=bnb_optimizer.state[p2][name2],
                 )
-            num_not_close = (
-                torch.isclose(
-                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
-                )
-                == 0
-            )
-            #assert num_not_close.sum().item() < 20
+            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
+            # assert num_not_close.sum().item() < 20
             dequant_states.append(s1.clone())
 
         err = torch.abs(p1 - p2)
-        relerr = err / (torch.abs(p1)+1e-9)
+        relerr = err / (torch.abs(p1) + 1e-9)
         if g.dtype == torch.bfloat16:
             assert err.mean() < 0.00015
             assert relerr.mean() < 0.0016
@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         relerrors.append(relerr.mean().item())
 
         if i % 10 == 0 and i > 0:
-            for (name1, name2, qmap, max_val), s in zip(
-                str2statenames[optim_name], dequant_states
-            ):
+            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                 s1cpy = s.clone()
                 raws1cpy = bnb_optimizer.state[p2][name2].clone()
                 qmap1 = bnb_optimizer.state[p2][qmap].clone()
@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                     )
                 torch.testing.assert_close(s1cpy, s1)
 
-                num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
+                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
                 assert num_not_close.sum().item() < 20
             # since Lion can have pretty noisy updates where things lie at the boundary
             # allow up to 5 errors for Lion
@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 
     for i in range(50):
         step += 1
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
-            0.01 * i
-        )
+        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
         g2 = g1.clone()
         p2.grad = g2
 
-        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
-            g1, gnorm_vec, step, 5
-        )
+        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
         g1 = (g1.float() * gnorm_scale).to(gtype)
         p1.grad = g1
 
@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
 
 @pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name"))
-@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode"))
+@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
+@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
 @pytest.mark.benchmark
 def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
     layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
     layers1 = layers1.cuda()
 
     large_tensor = None
-    if mode == 'torch':
+    if mode == "torch":
         optim = str2optimizers[optim_name][0](layers1.parameters())
     else:
         optim = str2optimizers[optim_name][1](layers1.parameters())
         # ~18 GB (4.5e9 fp32 elements)
-        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+        large_tensor = torch.empty((int(4.5e9),), device="cuda")
 
     torch.cuda.synchronize()
     time.sleep(5)
 
     num_batches = 5
-    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
-    lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
+    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
 
     for i in range(num_batches):
         print(i)
         b = batches[i]
-        if i ==2:
+        if i == 2:
             torch.cuda.synchronize()
             t0 = time.time()
 
diff --git a/tests/test_triton.py b/tests/test_triton.py
index 218a533d5..3624fb5e9 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -7,15 +7,18 @@
 from tests.helpers import TRUE_FALSE
 
 
-@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
-                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.skipif(
+    not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
+)
 @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
 def test_switchback(vector_wise_quantization):
     for dim in [83]:
         for batch in [13]:
-
             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            switchback = (
+                SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            )
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
@@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization):
 
             err_sb = (out_standard - out_sb).abs().mean()
             err_baseline = (out_standard - out_baseline).abs().mean()
-            print('OUT', err_sb, err_baseline)
+            print("OUT", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
             err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
 
-            print('GW2', err_sb,  err_baseline)
+            print("GW2", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
             err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
 
-            print('GW1', err_sb,  err_baseline)
+            print("GW1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (x1.grad - x2.grad).abs().mean()
             err_baseline = (x1.grad - x3.grad).abs().mean()
 
-            print('GX1', err_sb, err_baseline)
+            print("GX1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline