diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index a3bcc6981d5d..dceabd4c5a2a 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d
+Subproject commit dceabd4c5a2aa8cb29ce5a05311a57519baadddc
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index a2e6bce8cfea..3bc3b5defaf2 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -40,7 +40,7 @@ def _get_cutlass_path():
     return cutlass_path


-def _get_cutlass_compile_options(sm, threads):
+def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
     cutlass_root = _get_cutlass_path()
     cutlass_include = os.path.join(cutlass_root, "include")
     cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
@@ -58,6 +58,8 @@ def _get_cutlass_compile_options(sm, threads):
         "-I" + cutlass_include,
         "-I" + cutlass_util_include,
     ]
+    if use_fast_math:
+        kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
     cuda_path = find_cuda_path()
     cuda_ver = get_cuda_version(cuda_path)
     if cuda_ver >= 11.2:
@@ -222,6 +224,10 @@ def handle_conv2d(
         cutlass_op_def = out["opdef_bias_relu"]
     elif op_type == "cutlass.conv2d_bias_sigmoid":
         cutlass_op_def = out["opdef_bias_sigmoid"]
+    elif op_type == "cutlass.conv2d_bias_silu":
+        cutlass_op_def = out["opdef_bias_silu"]
+    elif op_type == "cutlass.conv2d_bias_hardswish":
+        cutlass_op_def = out["opdef_bias_hardswish"]
     else:
         raise ValueError("%s pattern is not implemented." % op_type)

@@ -339,7 +345,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
     return mod, num_cutlass_partition


-def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1):
+def build_cutlass_kernels(
+    lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False
+):
     """Compile CUTLASS kernels in lib and return the runtime module ready to run.

     Parameters
@@ -361,18 +369,27 @@ def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threa
         The number of threads to use for compiling generated kernels. Only available for
         CUDA 11.2 or later. Use all physical cores by default.

+    use_fast_math : bool, optional
+        Whether or not to use faster but less accurate math intrinsics.
+
     Returns
     -------
     updated_lib : runtime.Module
         The updated module with compiled cutlass kernels.
     """
-    kwargs = _get_cutlass_compile_options(sm, threads)
+    kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
     lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
     return runtime.load_module(lib_path)


 def build_cutlass_kernels_vm(
-    vm_exec, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro", threads=-1
+    vm_exec,
+    sm,
+    tmp_dir="./tmp",
+    lib_path="compile.so",
+    vmcode_path="vmcode.ro",
+    threads=-1,
+    use_fast_math=False,
 ):
     """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run.

@@ -398,13 +415,16 @@ def build_cutlass_kernels_vm(
         The number of threads to use for compiling generated kernels. Only available for
         CUDA 11.2 or later. Use all physical cores by default.

+    use_fast_math : bool, optional
+        Whether or not to use faster but less accurate math intrinsics.
+
     Returns
     -------
     updated_vm_exec: vm.Executable
         The updated exectuable with compiled cutlass kernels.
     """
     code, lib = vm_exec.save()
-    kwargs = _get_cutlass_compile_options(sm, threads)
+    kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
     lib_path = os.path.join(tmp_dir, lib_path)
     vmcode_path = os.path.join(tmp_dir, vmcode_path)
     lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
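For context, the new use_fast_math flag only changes the NVCC options used when the generated CUTLASS kernels are compiled (it adds -DCUTLASS_USE_TANH_FOR_SIGMOID). The sketch below is illustrative only and not part of the patch; it mirrors profile_and_build() from tests/python/contrib/test_cutlass.py further down. It assumes the helpers are importable from tvm.contrib.cutlass and tvm.relay.op.contrib.cutlass (as the module paths above suggest), that `mod` and `params` are an existing Relay module and its parameters, and that sm=80 is the target compute capability.

```python
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels
from tvm.relay.op.contrib.cutlass import partition_for_cutlass


def compile_with_cutlass(mod, params, sm=80, tmp_dir="./tmp", lib_path="compile.so"):
    # Offload supported subgraphs to CUTLASS and pick kernels for the target SM.
    mod = partition_for_cutlass(mod)
    mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, tmp_dir=tmp_dir)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="cuda", params=params)
    # use_fast_math=True adds -DCUTLASS_USE_TANH_FOR_SIGMOID to the kernel
    # compile options, trading a little accuracy for faster epilogues.
    lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path, use_fast_math=True)
    dev = tvm.device("cuda", 0)
    return graph_executor.GraphModule(lib["default"](dev)), num_cutlass_partition
```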
""" code, lib = vm_exec.save() - kwargs = _get_cutlass_compile_options(sm, threads) + kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math) lib_path = os.path.join(tmp_dir, lib_path) vmcode_path = os.path.join(tmp_dir, vmcode_path) lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 288f67f39287..43317f9054bb 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -90,9 +90,17 @@ def create_conv2d_operator( EpilogueFunctor.LinearCombinationBias, EpilogueFunctor.LinearCombinationRelu, EpilogueFunctor.LinearCombinationSigmoid, + EpilogueFunctor.LinearCombinationSilu, + EpilogueFunctor.LinearCombinationHardSwish, ], - ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"], - [True, True, False], + [ + "opdef_bias", + "opdef_bias_relu", + "opdef_bias_sigmoid", + "opdef_bias_silu", + "opdef_bias_hardswish", + ], + [True, True, False, False, False], ): op = Conv2dOperation( ConvKind.Fprop, diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 8c3f5eb5df63..efc5dd5ccd97 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,line-too-long """Various type definitions to help instantiate CUTLASS kernels.""" import re import enum @@ -149,6 +149,8 @@ class EpilogueFunctor(enum.Enum): LinearCombinationBias = enum_auto() LinearCombinationGelu = enum_auto() LinearCombinationSigmoid = enum_auto() + LinearCombinationSilu = enum_auto() + LinearCombinationHardSwish = enum_auto() EpilogueFunctorTag = { @@ -157,6 +159,8 @@ class EpilogueFunctor(enum.Enum): EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination", EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU", EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid", + EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu", + EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish", } diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 319062ddea41..24ccad5fe11d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1736,14 +1736,19 @@ def pad(inputs, input_types): paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] const_paddings = [] + non_zero_found = False for pad in paddings: const_paddings.append([]) for p in pad: if not isinstance(p, int): p = int(_infer_value(p, {}).numpy()) const_paddings[-1].append(p) + if p != 0: + non_zero_found = True - if mode == "constant": + if not non_zero_found: + return data + elif mode == "constant": return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) else: return _op.nn.pad(data, const_paddings, pad_mode=mode) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 8fdd90ea109a..eb36dc2d7c9f 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -75,6 +75,15 @@ def make_conv2d_pattern(with_bias=False, with_act=None): return is_op("nn.relu")(conv2d_out) if 
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index d06ebaa896f4..a87ba2f2cf1d 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -173,13 +173,9 @@ void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel)

 std::string DenseOp(std::string id, const Str2StrMap& attrs,
                     const std::vector<std::string>& func_args) {
-  bool has_bias = false;
+  bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
   bool is_gelu =
       attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos;  // fp32 or fp16
-  if (attrs.at("op_type") == "cutlass.dense_bias" ||
-      attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
-    has_bias = true;
-  }
   std::ostringstream gemm_decl;
   AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1);

@@ -263,10 +259,10 @@ Str2StrMap Conv2dArgs(const Map& attrs) {

 std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                      const std::vector<std::string>& func_args) {
-  bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
-                  attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
-                  attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
-  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";
+  bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
+  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
+                         attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
+                         attrs.at("op_type") != "cutlass.conv2d_bias_hardswish";

   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
@@ -505,6 +501,20 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
           GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_silu") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_silu", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_hardswish") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }

     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -546,14 +556,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.outputs.push_back(output);
     }
     decl_stream << ");";
-    if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" ||
-        func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") {
+    if (func_name.find("dense") != std::string::npos) {
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
-    } else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
-               func_name == "cutlass_conv2d_bias_relu" ||
-               func_name == "cutlass_conv2d_bias_sigmoid") {
+    } else if (func_name.find("conv2d") != std::string::npos) {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }

@@ -613,6 +620,9 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include \n";
     code_stream_ << "#include \n";
     code_stream_ << "#include \n";
+    code_stream_ << "#include \n";
+    code_stream_ << "#include \n";
+    code_stream_ << "#include \n";

     ICHECK(ref->IsInstance<FunctionNode>());
     auto res = GenCutlassFunc(Downcast<Function>(ref));
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index 64baa6066522..34e487cf1350 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -467,6 +467,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   int64_t num_axis = dshape.size();

   const auto* begin = types[1].as<TensorTypeNode>();
+  if (begin == nullptr) {
+    return false;
+  }
   ICHECK(begin);

   // calculate output shape
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 90a0e3150573..8a9b1a9505f6 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -289,6 +289,11 @@ bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "cast: expect input type to be TupleType but get " << types[0];
     return false;
   }
+  for (auto field : tensor_tuple->fields) {
+    if (field.as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   const auto* param = attrs.as<StackAttrs>();
   const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
   const int ndim = static_cast<int>(first->shape.size());
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 89099c86dc58..b2bdb8ca91a0 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -138,27 +138,48 @@ def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
     return relay.sigmoid(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))


-def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
+def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
+    conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
+    return conv_out * relay.sigmoid(conv_out)
+
+
+def get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16"):
+    conv2d_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
+    return conv2d_out * (
+        relay.clip(conv2d_out + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
+        / relay.const(6, dtype=out_dtype)
+    )
+
+
+def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
         mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
     )
     with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target="cuda", params=params)
-    lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path)
+    lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path, use_fast_math=use_fast_math)
     dev = tvm.device("cuda", 0)
     rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
     return rt_mod, dev, num_cutlass_partition


 def profile_and_build_vm(
-    mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro"
+    mod,
+    params,
+    sm,
+    tmp_dir="./tmp",
+    lib_path="compile.so",
+    vmcode_path="vmcode.ro",
+    use_fast_math=False,
 ):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, tmp_dir=tmp_dir)
     with tvm.transform.PassContext(opt_level=3):
         vm_exec = relay.vm.compile(mod, target="cuda", params=params)
-    vm_exec = build_cutlass_kernels_vm(vm_exec, sm, tmp_dir, lib_path, vmcode_path)
+    vm_exec = build_cutlass_kernels_vm(
+        vm_exec, sm, tmp_dir, lib_path, vmcode_path, use_fast_math=use_fast_math
+    )
     dev = tvm.device("cuda", 0)
     return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition

@@ -335,6 +356,7 @@ def verify_conv2d(
     rtol=1e-5,
     use_cudnn_ref=False,
     run_benchmark=False,
+    use_fast_math=False,
 ):
     if not has_cutlass():
         return
@@ -357,13 +379,13 @@ def verify_conv2d(
     mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})

     if use_vm:
-        rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
+        rt_mod, _, num_cutlass_partition = profile_and_build_vm(
+            mod_weight_ohwi, params, sm, use_fast_math=use_fast_math
+        )
         out = get_output_vm(rt_mod, ["data"], [np_data])
     else:
         rt_mod, _, num_cutlass_partition = profile_and_build(
-            mod_weight_ohwi,
-            params,
-            sm,
+            mod_weight_ohwi, params, sm, use_fast_math=use_fast_math
         )
         out = get_output(rt_mod, ["data"], [np_data])

@@ -438,11 +460,37 @@ def test_conv2d_fusion():
         mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
     )

+    mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+    verify_conv2d(
+        mod_nchw,
+        mod_nchw,
+        d_shape,
+        w_shape,
+        sm=80,
+        atol=1e-3,
+        rtol=1e-3,
+        run_benchmark=False,
+        use_fast_math=True,
+    )
+
     mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float32")
     verify_conv2d(
         mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
     )

+    mod_nchw = get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float32")
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+    mod_nchw = get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16")
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+

 if __name__ == "__main__":
     pytest.main([__file__])
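The fast-math sigmoid test above relaxes tolerances to 1e-3 rather than expecting different results. Judging from its name, the CUTLASS_USE_TANH_FOR_SIGMOID define appears to evaluate sigmoid through tanh, using an identity that is exact in real arithmetic; the remaining error comes from the device's fast tanh approximation and fp16 rounding. A short NumPy check of that identity, not part of the patch:

```python
# Sanity check of sigmoid(x) == 0.5 * (1 + tanh(x / 2)).
import numpy as np

x = np.linspace(-8.0, 8.0, 1001, dtype=np.float32)
sigmoid = 1.0 / (1.0 + np.exp(-x))
via_tanh = 0.5 * (1.0 + np.tanh(x / 2.0))
assert np.allclose(sigmoid, via_tanh, atol=1e-6)
```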