-
Notifications
You must be signed in to change notification settings - Fork 501
[Test]: Refactor benchmark_geglu with standardized model configs #1116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,163 +1,106 @@ | ||
| import math | ||
|
|
||
| import torch | ||
| import triton | ||
|
|
||
| from benchmark_model_configs import DEFAULT_MODEL_CONFIG | ||
| from benchmark_model_configs import MODEL_REGISTRY | ||
| from benchmark_model_configs import compute_benchmark_shape | ||
| from benchmark_model_configs import estimate_kernel_bytes_per_token | ||
| from transformers.models.llama.configuration_llama import LlamaConfig | ||
| from transformers.models.llama.modeling_llama import LlamaMLP | ||
| from utils import QUANTILES | ||
| from utils import SingleBenchmarkRunInput | ||
| from utils import SingleBenchmarkRunOutput | ||
| from utils import _test_memory | ||
| from utils import parse_benchmark_script_args | ||
| from utils import run_benchmarks | ||
| from utils import run_memory_benchmark | ||
| from utils import run_speed_benchmark | ||
|
|
||
| from liger_kernel.transformers.geglu import LigerGEGLUMLP | ||
| from liger_kernel.utils import infer_device | ||
| from liger_kernel.utils import get_total_gpu_memory | ||
|
|
||
| device = infer_device() | ||
|
|
||
|
|
||
| def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | ||
| seq_len = input.x | ||
| bsz = input.extra_benchmark_config["bsz"] | ||
| hidden_size = input.extra_benchmark_config["hidden_size"] | ||
| intermediate_size = input.extra_benchmark_config["intermediate_size"] | ||
| hidden_act = input.extra_benchmark_config["hidden_act"] | ||
| dtype = input.extra_benchmark_config["dtype"] | ||
| provider = input.kernel_provider | ||
| mode = input.kernel_operation_mode | ||
|
|
||
| def _setup_geglu(input: SingleBenchmarkRunInput): | ||
| """Create input tensor and GEGLU layer from benchmark config.""" | ||
| cfg = input.extra_benchmark_config | ||
| llama_config = LlamaConfig( | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| hidden_act=hidden_act, | ||
| hidden_size=cfg["hidden_size"], | ||
| intermediate_size=cfg["intermediate_size"], | ||
| hidden_act=cfg["hidden_act"], | ||
| ) | ||
|
|
||
| x_shape = (bsz, seq_len, hidden_size) | ||
|
|
||
| # initialize input | ||
| x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | ||
|
|
||
| if provider == "liger": | ||
| layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) | ||
| elif provider == "huggingface": | ||
| layer = LlamaMLP(config=llama_config).to(device).to(dtype) | ||
| else: | ||
| raise ValueError(f"Invalid provider: {provider} for GEGLU") | ||
|
|
||
| def fwd(): | ||
| return layer(x) | ||
|
|
||
| if mode == "forward": | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
| fwd, | ||
| grad_to_none=[x], | ||
| rep=10, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| elif mode == "backward": | ||
| do = torch.randn_like(x) | ||
| y = fwd() | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
| lambda: y.backward(do, retain_graph=True), | ||
| grad_to_none=[x], | ||
| rep=10, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| x = torch.randn( | ||
| cfg["bsz"], | ||
| input.x, | ||
| cfg["hidden_size"], | ||
| device=device, | ||
| dtype=cfg["dtype"], | ||
| requires_grad=True, | ||
| ) | ||
| if input.kernel_provider == "liger": | ||
| layer = LigerGEGLUMLP(config=llama_config).to(device).to(cfg["dtype"]) | ||
| elif input.kernel_provider == "huggingface": | ||
| layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"]) | ||
| else: | ||
| raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU") | ||
| return x, layer | ||
|
|
||
| def full(): | ||
| y = fwd() | ||
| y.backward(torch.randn_like(y), retain_graph=True) | ||
|
|
||
| ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
| full, | ||
| grad_to_none=[x], | ||
| rep=10, | ||
| quantiles=QUANTILES, | ||
| ) | ||
|
|
||
| return SingleBenchmarkRunOutput( | ||
| y_20=ms_20, | ||
| y_50=ms_50, | ||
| y_80=ms_80, | ||
| ) | ||
|
|
||
| def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | ||
| x, layer = _setup_geglu(input) | ||
| return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) | ||
|
|
||
|
|
||
| def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | ||
| seq_len = input.x | ||
| bsz = input.extra_benchmark_config["bsz"] | ||
| hidden_size = input.extra_benchmark_config["hidden_size"] | ||
| intermediate_size = input.extra_benchmark_config["intermediate_size"] | ||
| hidden_act = input.extra_benchmark_config["hidden_act"] | ||
| dtype = input.extra_benchmark_config["dtype"] | ||
| provider = input.kernel_provider | ||
| mode = input.kernel_operation_mode | ||
| x, layer = _setup_geglu(input) | ||
| return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) | ||
|
|
||
| llama_config = LlamaConfig( | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| hidden_act=hidden_act, | ||
| ) | ||
|
|
||
| x_shape = (bsz, seq_len, hidden_size) | ||
| # initialize input | ||
| x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | ||
| if __name__ == "__main__": | ||
| args = parse_benchmark_script_args() | ||
|
|
||
| if provider == "liger": | ||
| layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) | ||
| elif provider == "huggingface": | ||
| layer = LlamaMLP(config=llama_config).to(device).to(dtype) | ||
| else: | ||
| raise ValueError(f"Invalid provider: {provider} for GEGLU") | ||
|
|
||
| def fwd(): | ||
| return layer(x) | ||
|
|
||
| def full(): | ||
| y = fwd() | ||
| y.backward(torch.randn_like(y), retain_graph=True) | ||
|
|
||
| if mode == "forward": | ||
| mem_50, mem_20, mem_80 = _test_memory( | ||
| fwd, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| elif mode == "backward": | ||
| do = torch.randn_like(x) | ||
| y = fwd() | ||
| mem_50, mem_20, mem_80 = _test_memory( | ||
| lambda: y.backward(do, retain_graph=True), | ||
| quantiles=QUANTILES, | ||
| ) | ||
| else: | ||
| mem_50, mem_20, mem_80 = _test_memory( | ||
| full, | ||
| quantiles=QUANTILES, | ||
| ) | ||
|
|
||
| return SingleBenchmarkRunOutput( | ||
| y_20=mem_20, | ||
| y_50=mem_50, | ||
| y_80=mem_80, | ||
| model = MODEL_REGISTRY[args.model] if args.model else DEFAULT_MODEL_CONFIG | ||
| total_memory_gb = get_total_gpu_memory() | ||
|
|
||
| probe_seq_len = 1024 | ||
| probe_input = SingleBenchmarkRunInput( | ||
| x=probe_seq_len, | ||
| kernel_provider="huggingface", | ||
| extra_benchmark_config={ | ||
| "bsz": 1, | ||
| "hidden_size": model.hidden_size, | ||
| "intermediate_size": model.intermediate_size, | ||
| "hidden_act": "gelu_pytorch_tanh", | ||
| "dtype": model.dtype, | ||
| }, | ||
| ) | ||
| probe_x, probe_layer = _setup_geglu(probe_input) | ||
| kernel_bpt = estimate_kernel_bytes_per_token( | ||
| kernel_fn=lambda: probe_layer(probe_x), | ||
| num_tokens=probe_seq_len, | ||
| ) | ||
| del probe_x, probe_layer | ||
|
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| args = parse_benchmark_script_args() | ||
| shape = compute_benchmark_shape( | ||
| total_memory_gb, | ||
| model, | ||
| kernel_bytes_per_token=kernel_bpt, | ||
| ) | ||
|
|
||
| common_configs = { | ||
| "kernel_name": "geglu", | ||
| "x_name": "T", | ||
| "x_label": "sequence length", | ||
| "x_values": [2**i for i in range(10, 14)], | ||
| "x_values": [2**i for i in range(10, int(math.log2(shape.seq_len)) + 1)], | ||
| "kernel_providers": ["liger", "huggingface"], | ||
| "extra_benchmark_configs": [ | ||
| { | ||
| "bsz": 8, | ||
| "hidden_size": 4096, | ||
| "intermediate_size": 11008, | ||
| "bsz": shape.batch_size, | ||
| "hidden_size": model.hidden_size, | ||
| "intermediate_size": model.intermediate_size, | ||
| "hidden_act": "gelu_pytorch_tanh", | ||
| "dtype": torch.bfloat16, | ||
| "dtype": model.dtype, | ||
| } | ||
| ], | ||
| "overwrite": args.overwrite, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.