
Ruff fixes #984

Merged Jan 30, 2024 (16 commits)

Changes from 1 commit
Ruff: autofix I and F401
akx committed Jan 29, 2024
commit 6246f18bb62e549806d3ced3249ba136c2839e0d
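
For context on the rule codes in the commit title: Ruff's "I" rules sort and group imports in isort style (standard library first, then third-party, then first-party, each block alphabetized), and F401 flags imports that are never used; both have autofixes, typically applied with something like: ruff check --select I,F401 --fix
A minimal sketch of the effect, using illustrative module names that are not taken from this diff:

# Hypothetical module before the autofix: unsorted imports, one of them unused
import torch
import os          # never used below, so the F401 fix removes it
import json

# The same module after running the fix: standard-library imports first and
# alphabetized, third-party imports in their own block, the unused import gone
import json

import torch

print(json.dumps({"device": str(torch.device("cpu"))}))

The changes below (alphabetized import blocks, long imports wrapped in parentheses, unused imports dropped) are this fix applied across the repository.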
6 changes: 2 additions & 4 deletions benchmarking/switchback/make_plot_with_jsonl.py
@@ -1,9 +1,7 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

cmap=plt.get_cmap('cool')

20 changes: 14 additions & 6 deletions benchmarking/switchback/speed_benchmark.py
@@ -1,14 +1,22 @@
import json

import time

import torch
import torch.nn as nn

from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
int8_matmul_mixed_dequantize,
)
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import (
int8_matmul_rowwise_dequantize,
)
from bitsandbytes.triton.quantize_columnwise_and_transpose import (
quantize_columnwise_and_transpose,
)
from bitsandbytes.triton.quantize_global import (
quantize_global,
quantize_global_transpose,
)
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

4 changes: 2 additions & 2 deletions bitsandbytes/__init__.py
@@ -3,14 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, utils, research
from . import cuda_setup, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
matmul,
matmul_4bit,
matmul_cublas,
mm_cublas,
matmul_4bit
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
9 changes: 3 additions & 6 deletions bitsandbytes/__main__.py
@@ -1,11 +1,10 @@
import os
import sys
from os.path import isdir
import shlex
import subprocess

from warnings import warn
import sys
from typing import Tuple
from os.path import isdir
from warnings import warn

import torch

@@ -88,10 +87,8 @@ def print_debug_info() -> None:


from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities


print_header("OTHER")
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/__init__.py
@@ -1 +1 @@
from ._functions import undo_layout, get_inverse_transform_indices
from ._functions import get_inverse_transform_indices, undo_layout
6 changes: 3 additions & 3 deletions bitsandbytes/autograd/_functions.py
@@ -1,8 +1,8 @@
import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Tuple, Optional, List
import operator
from typing import Optional, Tuple
import warnings
from warnings import warn

import torch
7 changes: 2 additions & 5 deletions bitsandbytes/cextension.py
@@ -1,12 +1,9 @@
import ctypes as ct
import os
import torch

from pathlib import Path
from warnings import warn

from bitsandbytes.cuda_setup.main import CUDASetup
import torch

from bitsandbytes.cuda_setup.main import CUDASetup

setup = CUDASetup.get_instance()
if setup.initialized != True:
10 changes: 5 additions & 5 deletions bitsandbytes/cuda_setup/main.py
@@ -17,15 +17,15 @@
"""

import ctypes as ct
import os
import errno
import os
from pathlib import Path
import platform
import torch
from typing import Set, Union
from warnings import warn
from itertools import product

from pathlib import Path
from typing import Set, Union
import torch

from .env_vars import get_potentially_lib_path_containing_env_vars

# these are the most common libs names
12 changes: 5 additions & 7 deletions bitsandbytes/functional.py
@@ -3,17 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
from functools import reduce # Required in Python 3
import itertools
import operator
import random
import torch
import itertools
import math
import numpy as np
from typing import Any, Dict, Tuple

from functools import reduce # Required in Python 3
from typing import Tuple, Any, Dict
import numpy as np
import torch
from torch import Tensor

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import COMPILED_WITH_CUDA, lib
20 changes: 18 additions & 2 deletions bitsandbytes/nn/__init__.py
@@ -2,5 +2,21 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
from .modules import (
Embedding,
Int8Params,
Linear4bit,
Linear8bitLt,
LinearFP4,
LinearNF4,
OutlierAwareLinear,
Params4bit,
StableEmbedding,
SwitchBackLinearBnb,
)
from .triton_based_modules import (
StandardLinear,
SwitchBackLinear,
SwitchBackLinearGlobal,
SwitchBackLinearVectorwise,
)
8 changes: 4 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -3,17 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, TypeVar, Union, overload

import warnings

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
from bitsandbytes.utils import OutlierTracer

T = TypeVar("T", bound="torch.nn.Module")

24 changes: 16 additions & 8 deletions bitsandbytes/nn/triton_based_modules.py
@@ -1,16 +1,24 @@
import torch
import torch.nn as nn
import time
from functools import partial

from bitsandbytes.triton.triton_utils import is_triton_available
import torch
import torch.nn as nn

from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
int8_matmul_mixed_dequantize,
)
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import (
int8_matmul_rowwise_dequantize,
)
from bitsandbytes.triton.quantize_columnwise_and_transpose import (
quantize_columnwise_and_transpose,
)
from bitsandbytes.triton.quantize_global import (
quantize_global,
quantize_global_transpose,
)
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
from bitsandbytes.triton.triton_utils import is_triton_available


class _switchback_global(torch.autograd.Function):
11 changes: 9 additions & 2 deletions bitsandbytes/optim/__init__.py
@@ -7,10 +7,17 @@

from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
from .adamw import (
AdamW,
AdamW8bit,
AdamW32bit,
PagedAdamW,
PagedAdamW8bit,
PagedAdamW32bit,
)
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
from .sgd import SGD, SGD8bit, SGD32bit
1 change: 0 additions & 1 deletion bitsandbytes/optim/adamw.py
@@ -5,7 +5,6 @@
from bitsandbytes.optim.optimizer import Optimizer2State



class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
1 change: 1 addition & 0 deletions bitsandbytes/optim/lion.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
3 changes: 1 addition & 2 deletions bitsandbytes/optim/optimizer.py
@@ -2,8 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from itertools import chain

2 changes: 1 addition & 1 deletion bitsandbytes/research/__init__.py
@@ -1,6 +1,6 @@
from . import nn
from .autograd._functions import (
switchback_bnb,
matmul_fp8_global,
matmul_fp8_mixed,
switchback_bnb,
)
6 changes: 2 additions & 4 deletions bitsandbytes/research/autograd/_functions.py
@@ -1,14 +1,12 @@
from functools import reduce # Required in Python 3
import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3

import torch

from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState
import bitsandbytes.functional as F

from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler


# math.prod not compatible with python < 3.8
def prod(iterable):
2 changes: 1 addition & 1 deletion bitsandbytes/research/nn/__init__.py
@@ -1 +1 @@
from .modules import LinearFP8Mixed, LinearFP8Global
from .modules import LinearFP8Global, LinearFP8Mixed
7 changes: 2 additions & 5 deletions bitsandbytes/research/nn/modules.py
@@ -1,12 +1,9 @@
from typing import Optional, TypeVar, Union, overload
from typing import TypeVar

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch import nn

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims

T = TypeVar("T", bound="torch.nn.Module")

4 changes: 2 additions & 2 deletions bitsandbytes/triton/dequantize_rowwise.py
@@ -1,6 +1,7 @@
import math

import torch
import time

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
@@ -9,7 +10,6 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None

import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize

1 change: 1 addition & 0 deletions bitsandbytes/triton/int8_matmul_mixed_dequantize.py
@@ -1,4 +1,5 @@
import torch

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
4 changes: 2 additions & 2 deletions bitsandbytes/triton/quantize_columnwise_and_transpose.py
@@ -1,6 +1,7 @@
import math

import torch
import time

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
@@ -9,7 +10,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): return None

import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This kernel does fused columnwise quantization and transpose.

5 changes: 2 additions & 3 deletions bitsandbytes/triton/quantize_global.py
@@ -1,6 +1,6 @@
import math

import torch
import time

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
@@ -10,7 +10,6 @@ def quantize_global(x: torch.Tensor): return None

import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# global quantize
@triton.autotune(
3 changes: 1 addition & 2 deletions bitsandbytes/triton/quantize_rowwise.py
@@ -1,6 +1,6 @@
import math

import torch
import time

from bitsandbytes.triton.triton_utils import is_triton_available

@@ -10,7 +10,6 @@ def quantize_rowwise(x: torch.Tensor): return None

import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize
