Skip to content

Commit

Permalink
Add thunder benchmarks (#3394)
Browse files Browse the repository at this point in the history
Adds `thunder` as an additional executor to the baseline benchmarks and
the corresponding `thunder.jit` function.
The following benchmarks do not have a `thunder` benchmark:
1. `instancenorm`: Unsupported operator in Thunder
2. `test_gelu_backward_reduction.py`: `.backward` call is not supported
within Thunder definitions. @IvanYashchuk has suggested using explicit
backward computation for this case.

Issue #2718
  • Loading branch information
Priya2698 authored Jan 22, 2025
1 parent 92fdf42 commit 6a4f050
Show file tree
Hide file tree
Showing 32 changed files with 126 additions and 65 deletions.
2 changes: 1 addition & 1 deletion benchmarks/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def pytest_configure(config):
def pytest_collection_modifyitems(session, config, items):
"""
The baseline benchmarks use `executor` parameter with
values ["eager", "torchcompile", "thunder"] that are optionally
values ["eager", "torchcompile", "thunder", "thunder-torchcompile"] that are optionally
run using `--benchmark-{executor}` flag. They are skipped by
default.
"""
Expand Down
9 changes: 7 additions & 2 deletions benchmarks/python/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import thunder
from thunder.executors.nvfuserex import nvfuserex


# These variables can be overwritten through CLI commands
# --benchmark-rounds=rounds --benchmark-warmup-rounds=warmup_rounds
# --benchmark-num-inputs=num_inputs
Expand All @@ -23,6 +22,8 @@
L2_CACHE_SIZE = DEVICE_PROPERTIES["gpu_l2_bytes"]
PEAK_BANDWIDTH_GBPS = DEVICE_PROPERTIES["gpu_peak_bandwidth_gbps"]

DEFAULT_EXECUTORS = ["eager", "torchcompile", "thunder"]


def clear_l2_cache() -> None:
"""
Expand All @@ -44,7 +45,8 @@ def clear_dynamo_cache() -> None:


# Backward function for torch baseline benchmarks.
def unary_bwd_torch(inputs: List): # [output, grad_out]
# The first two inputs are expected to be out and grad_out. The remaining entries are the forward-pass inputs, whose grads are cleared between subsequent runs to avoid grad accumulation. See setup() in run_benchmark().
def unary_bwd_torch(inputs: List): # [output, grad_out, fwd_inputs]
inputs[0].backward(inputs[1], retain_graph=True)


Expand Down Expand Up @@ -329,6 +331,9 @@ def run_benchmark(
def setup():
clear_l2_cache()
if device == "cuda":
for inp in inputs:
if isinstance(inp, torch.Tensor):
inp.grad = None
return [inputs], {}

# Device = 'host'
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/python/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,6 @@ def norm_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=norm_bwd_iobytes(size, dtype, norm),
)
3 changes: 2 additions & 1 deletion benchmarks/python/test_batchnorm_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES
from .normalization import norm_bwd_nvf_benchmark, norm_bwd_baseline_benchmark
from .core import DEFAULT_EXECUTORS


@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
Expand All @@ -31,7 +32,7 @@ def test_batchnorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("channels_last", [True, False])
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/python/test_batchnorm_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES
from .normalization import norm_fwd_nvf_benchmark, norm_fwd_baseline_benchmark
from .core import DEFAULT_EXECUTORS


@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
Expand All @@ -31,7 +32,7 @@ def test_batchnorm_fwd_nvf_benchmark(
)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("channels_last", [True, False])
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/python/test_broadcast_add_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, with_executor
from .core import run_benchmark, clear_dynamo_cache, with_executor, DEFAULT_EXECUTORS
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES

Expand Down Expand Up @@ -88,7 +88,7 @@ def test_bcast_add_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [bias, x])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("bcast_axis", [0, 1], ids=["outer", "inner"])
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/python/test_dropout_layernorm_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
unary_bwd_torch,
compute_total_iobytes,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
Expand Down Expand Up @@ -191,7 +192,7 @@ def test_dropout_layernorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_layernorm_bwd_baseline_benchmark(
Expand Down Expand Up @@ -219,6 +220,6 @@ def test_dropout_layernorm_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=dropout_layernorm_bwd_iobytes(size, dtype),
)
3 changes: 2 additions & 1 deletion benchmarks/python/test_dropout_layernorm_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
clear_dynamo_cache,
compute_total_iobytes,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
Expand Down Expand Up @@ -151,7 +152,7 @@ def test_dropout_layernorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_layernorm_fwd_baseline_benchmark(
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/python/test_dropout_rmsnorm_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
unary_bwd_torch,
compute_total_iobytes,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
Expand Down Expand Up @@ -171,7 +172,7 @@ def test_dropout_rmsnorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_rmsnorm_bwd_baseline_benchmark(
Expand All @@ -195,6 +196,6 @@ def test_dropout_rmsnorm_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=dropout_rmsnorm_bwd_iobytes(size, dtype),
)
3 changes: 2 additions & 1 deletion benchmarks/python/test_dropout_rmsnorm_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
clear_dynamo_cache,
compute_total_iobytes,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
Expand Down Expand Up @@ -141,7 +142,7 @@ def test_dropout_rmsnorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [input1, input2, weights])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_rmsnorm_fwd_baseline_benchmark(
Expand Down
12 changes: 9 additions & 3 deletions benchmarks/python/test_gelu_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch, with_executor
from .core import (
run_benchmark,
clear_dynamo_cache,
unary_bwd_torch,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
import numpy as np
Expand Down Expand Up @@ -89,7 +95,7 @@ def test_gelu_bwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [inputs, grads, bias])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_gelu_bwd_baseline_benchmark(
Expand All @@ -111,6 +117,6 @@ def test_gelu_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=gelu_bwd_iobytes(size, dtype),
)
4 changes: 2 additions & 2 deletions benchmarks/python/test_gelu_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, with_executor
from .core import run_benchmark, clear_dynamo_cache, with_executor, DEFAULT_EXECUTORS
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
from .torch_ops import gelu
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_gelu_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_gelu_fwd_baseline_benchmark(
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/python/test_groupnorm_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, with_executor
from .core import run_benchmark, clear_dynamo_cache, with_executor, DEFAULT_EXECUTORS
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES

Expand Down Expand Up @@ -126,7 +126,7 @@ def test_groupnorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [x, weight, bias])


@pytest.mark.parametrize("executor", ["eager", "torchcompile", "thunder"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_groupnorm_fwd_baseline_benchmark(
Expand Down
12 changes: 9 additions & 3 deletions benchmarks/python/test_huggingface_attn_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch, with_executor
from .core import (
run_benchmark,
clear_dynamo_cache,
unary_bwd_torch,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES
from .torch_ops import huggingface_attn
Expand Down Expand Up @@ -108,7 +114,7 @@ def test_huggingface_attn_bwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [grads, attn, dropout_mask])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_attn_inputs())
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_huggingface_attn_bwd_baseline_benchmark(
Expand Down Expand Up @@ -138,6 +144,6 @@ def test_huggingface_attn_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=huggingface_attn_bwd_iobytes(size, dtype),
)
4 changes: 2 additions & 2 deletions benchmarks/python/test_huggingface_attn_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, with_executor
from .core import run_benchmark, clear_dynamo_cache, with_executor, DEFAULT_EXECUTORS
import torch
from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES
from .torch_ops import huggingface_attn
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_huggingface_attn_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [attention_mask, inputs])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_attn_inputs())
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_huggingface_attn_fwd_baseline_benchmark(
Expand Down
12 changes: 9 additions & 3 deletions benchmarks/python/test_layernorm_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch, with_executor
from .core import (
run_benchmark,
clear_dynamo_cache,
unary_bwd_torch,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
import numpy as np
Expand Down Expand Up @@ -147,7 +153,7 @@ def test_layernorm_bwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [inputs, grads, mean, invstd, weights])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_layernorm_bwd_baseline_benchmark(
Expand All @@ -172,6 +178,6 @@ def test_layernorm_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=layernorm_bwd_iobytes(size, dtype),
)
4 changes: 2 additions & 2 deletions benchmarks/python/test_layernorm_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, with_executor
from .core import run_benchmark, clear_dynamo_cache, with_executor, DEFAULT_EXECUTORS
import torch
from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES
import numpy as np
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_layernorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_layernorm_fwd_baseline_benchmark(
Expand Down
12 changes: 9 additions & 3 deletions benchmarks/python/test_nanogpt_attn_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch, with_executor
from .core import (
run_benchmark,
clear_dynamo_cache,
unary_bwd_torch,
with_executor,
DEFAULT_EXECUTORS,
)
import torch
from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES
from .torch_ops import nanogpt_attn
Expand Down Expand Up @@ -125,7 +131,7 @@ def test_nanogpt_attn_bwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [grads, attn, dropout_mask, bias_mask])


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_attn_inputs())
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_nanogpt_attn_bwd_baseline_benchmark(
Expand Down Expand Up @@ -156,6 +162,6 @@ def test_nanogpt_attn_bwd_baseline_benchmark(
run_benchmark(
benchmark,
unary_bwd_torch,
[outputs, grads],
[outputs, grads, *fwd_inputs],
iobytes=nanogpt_attn_bwd_iobytes(size, dtype),
)
Loading

0 comments on commit 6a4f050

Please sign in to comment.