Skip to content

Commit

Permalink
Add a few benchmarks for concatenation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tfogal committed Jan 23, 2025
1 parent 6a4f050 commit 2dd8309
Showing 1 changed file with 284 additions and 0 deletions.
284 changes: 284 additions & 0 deletions benchmarks/python/test_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import pytest
import torch
from nvfuser import FusionDefinition, DataType
import core


def phi3_cat_fusion_1(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(shape=[1, 48, 2048], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
T1 = fd.ops.permute(T0, dims=[0, 2, 1])
T2 = fd.ops.cat([T1, T1], dim=-1, manual_padding=0)
T3 = fd.ops.cos(T2)
T4 = fd.ops.sin(T2)
S5 = fd.define_scalar(1.19024, dtype=DataType.Double)
T6 = fd.ops.mul(T3, S5)
S7 = fd.define_scalar(1.19024, dtype=DataType.Double)
T8 = fd.ops.mul(T4, S7)
T9 = fd.ops.cast(T6, dtype=DataType.BFloat16)
T10 = fd.ops.cast(T8, dtype=DataType.BFloat16)
fd.add_output(T9)
fd.add_output(T10)


def phi3_cat_fusion_2(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(shape=[2048, 9216], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
S1 = fd.define_scalar(None, dtype=DataType.Int)
S2 = fd.define_scalar(None, dtype=DataType.Int)
T3 = fd.define_tensor(shape=[1, 2048, 9216], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
T4 = fd.define_tensor(shape=[1, 2048, 96], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 0, 1])
T5 = fd.define_tensor(shape=[1, 2048, 96], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 0, 1])
T10 = fd.ops.reshape(T0, new_shape=[1, 2048, 9216])
T11 = fd.ops.cast(T10, dtype=DataType.Float)
S12 = fd.define_scalar(0.00000, dtype=DataType.Double)
S13 = fd.define_scalar(1.00000, dtype=DataType.Double)
S14 = fd.define_scalar(1, dtype=DataType.Int)
S15 = fd.define_scalar(2048, dtype=DataType.Int)
S16 = fd.define_scalar(9216, dtype=DataType.Int)
T18 = fd.ops.uniform(S12, S13, shape=[S14, S15, S16], rng_seed=S2, rng_offset=S1, dtype=DataType.BFloat16)
S19 = fd.define_scalar(4.00000, dtype=DataType.Double)
T20 = fd.ops.mul(T11, S19)
S21 = fd.define_scalar(0.900000, dtype=DataType.Double)
T22 = fd.ops.lt(T18, S21)
T23 = fd.ops.cast(T22, dtype=DataType.Float)
T24 = fd.ops.mul(T20, T23)
S25 = fd.define_scalar(1.11111, dtype=DataType.Double)
T26 = fd.ops.mul(T24, S25)
T27 = fd.ops.cast(T3, dtype=DataType.Float)
T28 = fd.ops.add(T27, T26)
T29 = fd.ops.cast(T28, dtype=DataType.BFloat16)
T42 = fd.ops.slice(T29, start_indices=[0, 0, 3072], end_indices=[1, 2048, 6144], strides=[1, 1, 1], manual_normalization=0)
T55 = fd.ops.slice(T29, start_indices=[0, 0, 0], end_indices=[1, 2048, 3072], strides=[1, 1, 1], manual_normalization=0)
T61 = fd.ops.reshape(T42, new_shape=[1, 2048, 32, 96])
T67 = fd.ops.reshape(T55, new_shape=[1, 2048, 32, 96])
T68 = fd.ops.permute(T61, dims=[0, 2, 1, 3])
T69 = fd.ops.permute(T67, dims=[0, 2, 1, 3])
T85 = fd.ops.slice(T68, start_indices=[0, 0, 0, 48], end_indices=[1, 32, 2048, 96], strides=[1, 1, 1, 1], manual_normalization=0)
T101 = fd.ops.slice(T69, start_indices=[0, 0, 0, 48], end_indices=[1, 32, 2048, 96], strides=[1, 1, 1, 1], manual_normalization=0)
T102 = fd.ops.cast(T85, dtype=DataType.Float)
T103 = fd.ops.cast(T101, dtype=DataType.Float)
T104 = fd.ops.neg(T102)
T105 = fd.ops.neg(T103)
T111 = fd.ops.broadcast_in_dim(T4, shape=[1, 1, 2048, 96], broadcast_dims=[0, 2, 3])
T127 = fd.ops.slice(T68, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 2048, 48], strides=[1, 1, 1, 1], manual_normalization=0)
T128 = fd.ops.cast(T104, dtype=DataType.BFloat16)
T134 = fd.ops.broadcast_in_dim(T5, shape=[1, 1, 2048, 96], broadcast_dims=[0, 2, 3])
T150 = fd.ops.slice(T69, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 2048, 48], strides=[1, 1, 1, 1], manual_normalization=0)
T151 = fd.ops.cast(T105, dtype=DataType.BFloat16)
T157 = fd.ops.broadcast_in_dim(T111, shape=[1, 32, 2048, 96], broadcast_dims=[0, 1, 2, 3])
T158 = fd.ops.cat([T128, T127], dim=-1, manual_padding=0)
T164 = fd.ops.broadcast_in_dim(T134, shape=[1, 32, 2048, 96], broadcast_dims=[0, 1, 2, 3])
T165 = fd.ops.cat([T151, T150], dim=-1, manual_padding=0)
T166 = fd.ops.cast(T157, dtype=DataType.Float)
T167 = fd.ops.cast(T158, dtype=DataType.Float)
T168 = fd.ops.cast(T164, dtype=DataType.Float)
T169 = fd.ops.cast(T68, dtype=DataType.Float)
T170 = fd.ops.cast(T165, dtype=DataType.Float)
T171 = fd.ops.cast(T69, dtype=DataType.Float)
T172 = fd.ops.mul(T167, T166)
T173 = fd.ops.mul(T169, T168)
T174 = fd.ops.mul(T170, T166)
T175 = fd.ops.mul(T171, T168)
T188 = fd.ops.slice(T29, start_indices=[0, 0, 6144], end_indices=[1, 2048, 9216], strides=[1, 1, 1], manual_normalization=0)
T194 = fd.ops.reshape(T188, new_shape=[1, 2048, 32, 96])
T195 = fd.ops.add(T173, T172)
T196 = fd.ops.add(T175, T174)
T197 = fd.ops.permute(T194, dims=[0, 2, 1, 3])
T198 = fd.ops.cast(T195, dtype=DataType.BFloat16)
T199 = fd.ops.cast(T196, dtype=DataType.BFloat16)
T200 = fd.ops.stride_order(T197, stride_order=[3, 2, 1, 0])
T201 = fd.ops.stride_order(T198, stride_order=[3, 2, 1, 0])
T202 = fd.ops.stride_order(T199, stride_order=[3, 2, 1, 0])
fd.add_output(T197)
fd.add_output(T198)
fd.add_output(T200)
fd.add_output(T201)
fd.add_output(T202)


def qwen2_cat_fusion_1(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(shape=[1, 64, 2048], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
T1 = fd.define_tensor(shape=[1, 2048, 3584], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
T2 = fd.define_tensor(shape=[3584], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
T3 = fd.ops.permute(T0, dims=[0, 2, 1])
T4 = fd.ops.cat([T3, T3], dim=-1, manual_padding=0)
T5 = fd.ops.cos(T4)
T6 = fd.ops.sin(T4)
S7 = fd.define_scalar(1.00000, dtype=DataType.Double)
T8 = fd.ops.mul(T5, S7)
S9 = fd.define_scalar(1.00000, dtype=DataType.Double)
T10 = fd.ops.mul(T6, S9)
T11 = fd.ops.cast(T8, dtype=DataType.BFloat16)
T12 = fd.ops.cast(T10, dtype=DataType.BFloat16)
T13 = fd.ops.cast(T1, dtype=DataType.Float)
S14 = fd.define_scalar(2.00000, dtype=DataType.Double)
T15 = fd.ops.pow(T13, S14)
T16 = fd.ops.sum(T15, dims=[2], keepdim=False, dtype=DataType.Null)
T21 = fd.ops.broadcast_in_dim(T16, shape=[1, 2048, 1], broadcast_dims=[0, 1])
S22 = fd.define_scalar(3584.00, dtype=DataType.Double)
S23 = fd.ops.reciprocal(S22)
T24 = fd.ops.mul(T21, S23)
S25 = fd.define_scalar(1.00000e-06, dtype=DataType.Double)
T26 = fd.ops.add(T24, S25)
T27 = fd.ops.rsqrt(T26)
T32 = fd.ops.broadcast_in_dim(T27, shape=[1, 2048, 3584], broadcast_dims=[0, 1, 2])
T33 = fd.ops.mul(T13, T32)
T38 = fd.ops.broadcast_in_dim(T2, shape=[1, 2048, 3584], broadcast_dims=[2])
T39 = fd.ops.cast(T38, dtype=DataType.Float)
T40 = fd.ops.mul(T39, T33)
T41 = fd.ops.cast(T40, dtype=DataType.BFloat16)
fd.add_output(T11)
fd.add_output(T12)
fd.add_output(T41)


def qwen2_cat_fusion_2(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(shape=[2048, 512], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
S1 = fd.define_scalar(None, dtype=DataType.Int)
S2 = fd.define_scalar(None, dtype=DataType.Int)
T3 = fd.define_tensor(shape=[1, 4, 2048, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
T4 = fd.define_tensor(shape=[1, 2048, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 0, 1])
T5 = fd.define_tensor(shape=[1, 2048, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 0, 1])
T6 = fd.define_tensor(shape=[1, 28, 2048, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
T7 = fd.define_tensor(shape=[1, 2048, 512], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
T12 = fd.ops.reshape(T0, new_shape=[1, 2048, 512])
T13 = fd.ops.cast(T12, dtype=DataType.Float)
S14 = fd.define_scalar(0.00000, dtype=DataType.Double)
S15 = fd.define_scalar(1.00000, dtype=DataType.Double)
S16 = fd.define_scalar(1, dtype=DataType.Int)
S17 = fd.define_scalar(2048, dtype=DataType.Int)
S18 = fd.define_scalar(512, dtype=DataType.Int)
T20 = fd.ops.uniform(S14, S15, shape=[S16, S17, S18], rng_seed=S2, rng_offset=S1, dtype=DataType.BFloat16)
S21 = fd.define_scalar(4.00000, dtype=DataType.Double)
T22 = fd.ops.mul(T13, S21)
S23 = fd.define_scalar(0.900000, dtype=DataType.Double)
T24 = fd.ops.lt(T20, S23)
T25 = fd.ops.cast(T24, dtype=DataType.Float)
T41 = fd.ops.slice(T3, start_indices=[0, 0, 0, 64], end_indices=[1, 4, 2048, 128], strides=[1, 1, 1, 1], manual_normalization=0)
T42 = fd.ops.mul(T22, T25)
T43 = fd.ops.cast(T41, dtype=DataType.Float)
T44 = fd.ops.neg(T43)
T50 = fd.ops.broadcast_in_dim(T4, shape=[1, 1, 2048, 128], broadcast_dims=[0, 2, 3])
T66 = fd.ops.slice(T3, start_indices=[0, 0, 0, 0], end_indices=[1, 4, 2048, 64], strides=[1, 1, 1, 1], manual_normalization=0)
T67 = fd.ops.cast(T44, dtype=DataType.BFloat16)
T73 = fd.ops.broadcast_in_dim(T5, shape=[1, 1, 2048, 128], broadcast_dims=[0, 2, 3])
T89 = fd.ops.slice(T6, start_indices=[0, 0, 0, 64], end_indices=[1, 28, 2048, 128], strides=[1, 1, 1, 1], manual_normalization=0)
S90 = fd.define_scalar(1.11111, dtype=DataType.Double)
T91 = fd.ops.mul(T42, S90)
T97 = fd.ops.broadcast_in_dim(T50, shape=[1, 4, 2048, 128], broadcast_dims=[0, 1, 2, 3])
T98 = fd.ops.cat([T67, T66], dim=-1, manual_padding=0)
T104 = fd.ops.broadcast_in_dim(T73, shape=[1, 4, 2048, 128], broadcast_dims=[0, 1, 2, 3])
T105 = fd.ops.cast(T89, dtype=DataType.Float)
T106 = fd.ops.cast(T97, dtype=DataType.Float)
T107 = fd.ops.cast(T98, dtype=DataType.Float)
T108 = fd.ops.cast(T104, dtype=DataType.Float)
T109 = fd.ops.cast(T3, dtype=DataType.Float)
T110 = fd.ops.neg(T105)
T111 = fd.ops.cast(T7, dtype=DataType.Float)
T112 = fd.ops.mul(T107, T106)
T113 = fd.ops.mul(T109, T108)
T129 = fd.ops.slice(T6, start_indices=[0, 0, 0, 0], end_indices=[1, 28, 2048, 64], strides=[1, 1, 1, 1], manual_normalization=0)
T130 = fd.ops.cast(T110, dtype=DataType.BFloat16)
T131 = fd.ops.add(T111, T91)
T137 = fd.ops.broadcast_in_dim(T50, shape=[1, 28, 2048, 128], broadcast_dims=[0, 1, 2, 3])
T138 = fd.ops.cat([T130, T129], dim=-1, manual_padding=0)
T144 = fd.ops.broadcast_in_dim(T73, shape=[1, 28, 2048, 128], broadcast_dims=[0, 1, 2, 3])
T145 = fd.ops.cast(T131, dtype=DataType.BFloat16)
T146 = fd.ops.cast(T137, dtype=DataType.Float)
T147 = fd.ops.cast(T138, dtype=DataType.Float)
T148 = fd.ops.cast(T144, dtype=DataType.Float)
T149 = fd.ops.cast(T6, dtype=DataType.Float)
T155 = fd.ops.reshape(T145, new_shape=[1, 2048, 4, 128])
T156 = fd.ops.add(T113, T112)
T157 = fd.ops.mul(T147, T146)
T158 = fd.ops.mul(T149, T148)
T159 = fd.ops.permute(T155, dims=[0, 2, 1, 3])
T160 = fd.ops.cast(T156, dtype=DataType.BFloat16)
T167 = fd.ops.broadcast_in_dim(T159, shape=[1, 4, 1, 2048, 128], broadcast_dims=[0, 1, 3, 4])
T174 = fd.ops.broadcast_in_dim(T160, shape=[1, 4, 1, 2048, 128], broadcast_dims=[0, 1, 3, 4])
T181 = fd.ops.broadcast_in_dim(T167, shape=[1, 4, 7, 2048, 128], broadcast_dims=[0, 1, 2, 3, 4])
T188 = fd.ops.broadcast_in_dim(T174, shape=[1, 4, 7, 2048, 128], broadcast_dims=[0, 1, 2, 3, 4])
T189 = fd.ops.add(T158, T157)
T195 = fd.ops.reshape(T181, new_shape=[1, 28, 2048, 128])
T201 = fd.ops.reshape(T188, new_shape=[1, 28, 2048, 128])
T202 = fd.ops.cast(T189, dtype=DataType.BFloat16)
T203 = fd.ops.stride_order(T195, stride_order=[3, 2, 1, 0])
T204 = fd.ops.stride_order(T201, stride_order=[3, 2, 1, 0])
T205 = fd.ops.stride_order(T202, stride_order=[3, 2, 1, 0])
fd.add_output(T159)
fd.add_output(T160)
fd.add_output(T203)
fd.add_output(T204)
fd.add_output(T205)


def test_cat_phi3_v1(benchmark):
inputs = [
torch.randn((1, 48, 2048), dtype=torch.float32, device='cuda:0')
]

with FusionDefinition() as fd:
phi3_cat_fusion_1(fd)

def benchmark_fn(inputs):
fd.execute(inputs)

core.run_benchmark(benchmark, benchmark_fn, inputs)


def test_cat_phi3_v2(benchmark):
inputs = [
torch.testing.make_tensor((2048, 9216), dtype=torch.bfloat16, device='cuda:0'),
13284,
7808972590298790,
torch.testing.make_tensor((1, 2048, 9216), dtype=torch.bfloat16, device='cuda:0'),
torch.randn(196608, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 96), (196608, 1, 2048)),
torch.randn(196608, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 96), (196608, 1, 2048)),
]
with FusionDefinition() as fd:
phi3_cat_fusion_2(fd)

def benchmark_fn(inputs):
fd.execute(inputs)

core.run_benchmark(benchmark, benchmark_fn, inputs)


# Note that this is also basically the same fusion the Mistral-Nemo network.
def test_cat_qwen2_v1(benchmark):
inputs = [
torch.testing.make_tensor((1, 64, 2048), dtype=torch.float32, device='cuda:0'),
torch.testing.make_tensor((1, 2048, 3584), dtype=torch.bfloat16, device='cuda:0'),
torch.testing.make_tensor((3584,), dtype=torch.bfloat16, device='cuda:0'),
]
with FusionDefinition() as fd:
qwen2_cat_fusion_1(fd)

def benchmark_fn(inputs):
fd.execute(inputs)

core.run_benchmark(benchmark, benchmark_fn, inputs)


@pytest.mark.skip("issue 3740")
def test_cat_qwen2_v2(benchmark):
inputs = [
torch.testing.make_tensor((2048, 512), dtype=torch.bfloat16, device='cuda:0'),
25546,
1400552702872758,
torch.randn(1048576, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 4, 2048, 128), (1048576, 128, 512, 1)),
torch.randn(262144, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 128), (262144, 1, 2048)),
torch.randn(262144, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 128), (262144, 1, 2048)),
torch.randn(7340032, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 2048, 128), (7340032, 128, 3584, 1)),
torch.testing.make_tensor((1, 2048, 512), dtype=torch.bfloat16, device='cuda:0'),
]

with FusionDefinition() as fd:
qwen2_cat_fusion_2(fd)

def benchmark_fn(inputs):
fd.execute(inputs)

core.run_benchmark(benchmark, benchmark_fn, inputs)

0 comments on commit 2dd8309

Please sign in to comment.