Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 22 additions & 88 deletions benchmark/scripts/benchmark_dyt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.utils import infer_device

Expand All @@ -18,98 +17,33 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
def _setup_dyt(input: SingleBenchmarkRunInput):
"""Create input tensor and DyT layer from benchmark config."""
from test.transformers.test_dyt import LigerDyT
from test.transformers.test_dyt import TorchDyT

cfg = input.extra_benchmark_config
hidden_size = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
BT = extra_benchmark_config["BT"]
beta = extra_benchmark_config["beta"]
dtype = extra_benchmark_config["dtype"]

x_shape = (BT, hidden_size)
torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)

x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

def fwd():
if provider == "liger":
return triton_dyt(x)
elif provider == "torch":
return torch_dyt(x)
elif provider == "torch_compile":
return torch_compile_dyt(x)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
elif mode == "full":
x = torch.randn(cfg["BT"], hidden_size, device=device, dtype=cfg["dtype"], requires_grad=True)
if input.kernel_provider == "liger":
layer = LigerDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
elif input.kernel_provider == "torch":
layer = TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
elif input.kernel_provider == "torch_compile":
layer = torch.compile(TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device))
else:
raise ValueError(f"Invalid provider: {input.kernel_provider} for DyT")
return x, layer

def full():
y = fwd()
y.backward(dy)

ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure DyT kernel speed (ms) for the provider/mode in *input*."""
    x, layer = _setup_dyt(input)

    # Named closure instead of an inline lambda; `x` is passed as the
    # grad-to-none tensor list expected by run_speed_benchmark.
    def forward():
        return layer(x)

    return run_speed_benchmark(forward, input.kernel_operation_mode, [x])


def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.transformers.test_dyt import LigerDyT
from test.transformers.test_dyt import TorchDyT

hidden_size = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
BT = extra_benchmark_config["BT"]
beta = extra_benchmark_config["beta"]
dtype = extra_benchmark_config["dtype"]

x_shape = (BT, hidden_size)
torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)

x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

def fwd():
if provider == "liger":
return triton_dyt(x)
elif provider == "torch":
return torch_dyt(x)
elif provider == "torch_compile":
return torch_compile_dyt(x)

def full():
y = fwd()
y.backward(dy, retain_graph=True)

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
x, layer = _setup_dyt(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)


if __name__ == "__main__":
Expand All @@ -128,14 +62,14 @@ def full():

run_benchmarks(
bench_test_fn=bench_speed_dyt,
kernel_operation_modes=["forward", "backward", "full"],
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_dyt,
kernel_operation_modes=["full"],
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
Expand Down
191 changes: 67 additions & 124 deletions benchmark/scripts/benchmark_geglu.py
Original file line number Diff line number Diff line change
@@ -1,163 +1,106 @@
import math

import torch
import triton

from benchmark_model_configs import DEFAULT_MODEL_CONFIG
from benchmark_model_configs import MODEL_REGISTRY
from benchmark_model_configs import compute_benchmark_shape
from benchmark_model_configs import estimate_kernel_bytes_per_token
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.utils import infer_device
from liger_kernel.utils import get_total_gpu_memory

device = infer_device()


def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

def _setup_geglu(input: SingleBenchmarkRunInput):
"""Create input tensor and GEGLU layer from benchmark config."""
cfg = input.extra_benchmark_config
llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_size=cfg["hidden_size"],
intermediate_size=cfg["intermediate_size"],
hidden_act=cfg["hidden_act"],
)

x_shape = (bsz, seq_len, hidden_size)

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

if provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")

def fwd():
return layer(x)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(do, retain_graph=True),
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
x = torch.randn(
cfg["bsz"],
input.x,
cfg["hidden_size"],
device=device,
dtype=cfg["dtype"],
requires_grad=True,
)
if input.kernel_provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(cfg["dtype"])
elif input.kernel_provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"])
else:
raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU")
return x, layer

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)

def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure GEGLU MLP speed (ms) for the provider/mode in *input*."""
    x, layer = _setup_geglu(input)

    # Named closure instead of an inline lambda; `x` is passed as the
    # grad-to-none tensor list expected by run_speed_benchmark.
    def forward():
        return layer(x)

    return run_speed_benchmark(forward, input.kernel_operation_mode, [x])


def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
x, layer = _setup_geglu(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)

llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
)

x_shape = (bsz, seq_len, hidden_size)
# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
if __name__ == "__main__":
args = parse_benchmark_script_args()

if provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")

def fwd():
return layer(x)

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

if mode == "forward":
mem_50, mem_20, mem_80 = _test_memory(
fwd,
quantiles=QUANTILES,
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
mem_50, mem_20, mem_80 = _test_memory(
lambda: y.backward(do, retain_graph=True),
quantiles=QUANTILES,
)
else:
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
model = MODEL_REGISTRY[args.model] if args.model else DEFAULT_MODEL_CONFIG
total_memory_gb = get_total_gpu_memory()

probe_seq_len = 1024
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="huggingface",
extra_benchmark_config={
"bsz": 1,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": model.dtype,
},
)
probe_x, probe_layer = _setup_geglu(probe_input)
kernel_bpt = estimate_kernel_bytes_per_token(
kernel_fn=lambda: probe_layer(probe_x),
num_tokens=probe_seq_len,
)
del probe_x, probe_layer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe input setup and cleanup should be done in estimate_kernel_bytes_per_token, so developers don't have to worry about them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestion!



if __name__ == "__main__":
args = parse_benchmark_script_args()
shape = compute_benchmark_shape(
total_memory_gb,
model,
kernel_bytes_per_token=kernel_bpt,
)

common_configs = {
"kernel_name": "geglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 14)],
"x_values": [2**i for i in range(10, int(math.log2(shape.seq_len)) + 1)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"bsz": 8,
"hidden_size": 4096,
"intermediate_size": 11008,
"bsz": shape.batch_size,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": torch.bfloat16,
"dtype": model.dtype,
}
],
"overwrite": args.overwrite,
Expand Down
Loading
Loading