diff --git a/CMakeLists.txt b/CMakeLists.txt index 040d8129b72..d1d10720550 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,6 +136,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/host_ir/container.cpp ${NVFUSER_SRCS_DIR}/host_ir/executor.cpp ${NVFUSER_SRCS_DIR}/host_ir/host_ir.cpp + ${NVFUSER_SRCS_DIR}/host_ir/lower.cpp ${NVFUSER_SRCS_DIR}/id_model/circular_buffer_indexing.cpp ${NVFUSER_SRCS_DIR}/id_model/contiguity.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp @@ -170,7 +171,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp ${NVFUSER_SRCS_DIR}/multidevice/executor.cpp - ${NVFUSER_SRCS_DIR}/multidevice/lower_communication.cpp ${NVFUSER_SRCS_DIR}/multidevice/utils.cpp ${NVFUSER_SRCS_DIR}/mutator.cpp ${NVFUSER_SRCS_DIR}/non_divisible_split.cpp @@ -200,6 +200,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp ${NVFUSER_SRCS_DIR}/runtime/allocations.cpp ${NVFUSER_SRCS_DIR}/runtime/executor.cpp @@ -231,8 +232,10 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp ${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/resize.cpp ${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp ${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp + ${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp @@ -294,13 +297,18 @@ endif() add_library(codegen_internal OBJECT ${NVFUSER_SRCS}) if(NOT MSVC) - # -Werror is not enabled, because of gcc 12.2 used in manylinux image. - # consider enable this when we upgrade. linking comment: - # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266 - target_compile_options(codegen_internal PRIVATE - -Wall -Wno-unused-function - # -Werror - ) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + target_compile_options(codegen_internal PRIVATE + -Wall -Wno-unused-function -Werror + # These warnings are not treated as errors because of gcc 12.2 used in + # manylinux image. consider enable this when we upgrade. 
+ # linking comment: + # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266 + -Wno-error=restrict -Wno-error=stringop-overflow) + else() + target_compile_options(codegen_internal PRIVATE + -Wall -Wno-unused-function -Werror) + endif() endif() target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") @@ -440,6 +448,7 @@ if(BUILD_PYTHON) list(APPEND NVFUSER_PYTHON_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp ) add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) @@ -573,6 +582,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_resharding.cpp ${NVFUSER_ROOT}/tests/cpp/test_resize.cpp ${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp + ${NVFUSER_ROOT}/tests/cpp/test_rope.cpp ${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp ${NVFUSER_ROOT}/tests/cpp/test_scatter_gather.cpp ${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp @@ -584,6 +594,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_tensor_factories.cpp ${NVFUSER_ROOT}/tests/cpp/test_unary.cpp ${NVFUSER_ROOT}/tests/cpp/test_utils.cpp + ${NVFUSER_ROOT}/tests/cpp/test_vectorization_analysis.cpp ) if(BUILD_TEST) @@ -644,6 +655,7 @@ if(BUILD_TEST) set(MULTIDEVICE_TEST_SRCS) list(APPEND MULTIDEVICE_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp + ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp diff --git a/benchmarks/python/core.py b/benchmarks/python/core.py index d7b90033f56..aea3662c5cb 100644 --- a/benchmarks/python/core.py +++ b/benchmarks/python/core.py @@ -152,7 +152,6 @@ def torchprofile_timer(self) -> float: # Clear the internal profiler object to avoid accumulating function events and then restart the profiler # See PR: https://github.com/pytorch/pytorch/pull/125510 self.prof.profiler = None - self.prof.start() return self.current_time diff --git a/csrc/bfs.h b/csrc/bfs.h index 79a9e29b7db..206026053c8 100644 --- a/csrc/bfs.h +++ b/csrc/bfs.h @@ -549,6 +549,77 @@ class BFS { Direction allowed_direction_ = Direction::Undefined; }; +// Unlike the default BFS behavior, Expr is considered ready to +// visit as long as one of the inputs or outputs has any of its dependencies met +template < + typename ExprT, + typename ValT, + typename DefinitionT, + typename UsesT, + typename InputsT, + typename OutputsT> +class BFSWithPermissiveDependence + : public BFS { + public: + using NodeType = + typename BFS:: + NodeType; + + BFSWithPermissiveDependence( + DefinitionT definition, + UsesT uses, + InputsT inputs, + OutputsT outputs, + std::vector from, + std::vector to, + bool require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined) + : BFS( + definition, + uses, + inputs, + outputs, + std::move(from), + std::move(to), + require_all_to_visited, + allowed_direction) {} + + std::optional>> isReady( + const ExprT& expr) const override { + // Either any inputs or any outputs must have been visited + decltype(auto) inputs = this->inputs_(expr); + if (!inputs.empty() && this->allowed_direction_ != Direction::Backward && + std::any_of( + inputs.begin(), inputs.end(), [&](const ValT& input) -> bool { + return this->isDependencySatisfied(input); + })) { + std::vector prev_nodes; + std::copy_if( + inputs.begin(), + inputs.end(), + 
std::back_inserter(prev_nodes), + [&](const ValT& input) -> bool { return this->isVisited(input); }); + return std::make_pair(Direction::Forward, prev_nodes); + } + + decltype(auto) outputs = this->outputs_(expr); + if (!outputs.empty() && this->allowed_direction_ != Direction::Forward && + std::any_of( + outputs.begin(), outputs.end(), [&](const ValT& output) -> bool { + return this->isDependencySatisfied(output); + })) { + std::vector prev_nodes; + std::copy_if( + outputs.begin(), + outputs.end(), + std::back_inserter(prev_nodes), + [&](const ValT& output) -> bool { return this->isVisited(output); }); + return std::make_pair(Direction::Backward, prev_nodes); + } + return std::nullopt; + } +}; + // Find the shortest path from the from vals to the to // vals. Dependency between vals and exprs must be satisfied. // It is an error if no valid path is found unless diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 0060e626fe6..3a5f31c74d5 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -3026,17 +3026,22 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } else { step_code << gen_index << " += " << gen_step; } - if (loop->isUnrolled()) { - indent() << "#pragma unroll\n"; - } else if ( - loop->circularBufferLoopStage() == CircularBufferLoopStage::Epilog) { - indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth() - 1 - << "\n"; - } else if ( - loop->circularBufferLoopStage() != + if (loop->circularBufferLoopStage() != CircularBufferLoopStage::NotApplicable) { - indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth() - << "\n"; + // NOTE: requireUnroll is sometimes called on a circular-buffered matmul + // loops when static shapes are used. To avoid hinting that the compiler + // should maximally unroll such loops leading to very long compiles, we + // handle that case explicitly here and ignore loop->isUnrolled(). 
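  // Illustrative example: with a prefetch distance of 3, the loop header
  // emitted below is "#pragma unroll 3" rather than an unbounded
  // "#pragma unroll".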
+ // + // Unroll "prefetch" many circular buffered loops regardless of buffer + // stage (prologue, main, or epilogue) + int64_t prefetch = kernel_->summary() + .circular_buffer_info + .getCircularBufferOptionsFor(loop->iter_domain()) + .prefetch; + indent() << "#pragma unroll " << prefetch << "\n"; + } else if (loop->isUnrolled()) { + indent() << "#pragma unroll\n"; } else { indent() << "#pragma unroll 1\n"; } @@ -3505,6 +3510,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n"; } + void handle(const kir::Return* ret) final { + indent() << "return;\n"; + } + private: std::stringstream code_; const kir::Kernel* kernel_; diff --git a/csrc/device_lower/analysis/circular_buffer.cpp b/csrc/device_lower/analysis/circular_buffer.cpp index 05dfe1a3e9c..58f35a1f8f0 100644 --- a/csrc/device_lower/analysis/circular_buffer.cpp +++ b/csrc/device_lower/analysis/circular_buffer.cpp @@ -232,9 +232,11 @@ IterDomain* CircularBufferInfo::getCircularBufferAxis( const CircularBufferOptions& CircularBufferInfo::getCircularBufferOptionsFor( IterDomain* circular_buffer_axis) const { - auto concrete_id = lower_utils::getConcreteLoopID(circular_buffer_axis); + if (GpuLower::hasCurrent()) { + circular_buffer_axis = lower_utils::getConcreteLoopID(circular_buffer_axis); + } - auto maybe_depth_it = circular_buffer_options_.find(concrete_id); + auto maybe_depth_it = circular_buffer_options_.find(circular_buffer_axis); NVF_ERROR( maybe_depth_it != circular_buffer_options_.end(), diff --git a/csrc/device_lower/analysis/device_version.cpp b/csrc/device_lower/analysis/device_version.cpp index 98b4a7300d0..4682adfaf75 100644 --- a/csrc/device_lower/analysis/device_version.cpp +++ b/csrc/device_lower/analysis/device_version.cpp @@ -5,6 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include + #include #include #include @@ -19,9 +21,22 @@ void MinimumDeviceVersion::dispatch(Val* val) { } if (val->dtype() == DataType::Float8_e4m3fn || val->dtype() == DataType::Float8_e5m2) { +// See release note +// https://docs.nvidia.com/cuda/archive/12.1.0/parallel-thread-execution/index.html#ptx-isa-version-8-1 +#if (CUDA_VERSION >= 12010) + ensureVersion( + {8, 9}, + "Fusion contains Float8_xxx values which was introduced in Ada (8.9)"); +// See release note +// https://docs.nvidia.com/cuda/archive/11.8.0/parallel-thread-execution/index.html#ptx-isa-version-7-8 +#elif (CUDA_VERSION >= 11080) ensureVersion( {9, 0}, "Fusion contains Float8_xxx values which was introduced in Hopper (9.0)"); +#else + NVF_ERROR( + "Fusion contains Float8_xxx values which was not supported in given CUDA version"); +#endif // (CUDA_VERSION >= 12010) } IterVisitor::dispatch(val); } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index f4121021f3b..4605eb9eac4 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1686,8 +1686,8 @@ Val* hardCodedIndexGenerationForStMatrix( Val* out_index = nullptr; NVF_ERROR( - ldst->out()->dtype() == DataType::Half, - "we only support half type in stmatrix"); + dataTypeSize(ldst->out()->dtype()) == 2, + "we only support 16-bit types in stmatrix"); NVF_ERROR(ldst->out()->isA()); TensorView* out_tv = ldst->out()->as(); @@ -1959,8 +1959,8 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle( "size not currently supported for stmatrix"); NVF_ERROR( - ldst->out()->dtype() == DataType::Half, - "we only support half type in stmatrix"); + dataTypeSize(ldst->out()->dtype()) == 2, + "we only 
support 16-bit types in stmatrix"); NVF_ERROR(ldst->out()->isA()); TensorView* out_tv = ldst->out()->as(); @@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) { pushBack(const_cast(fence)); // NOLINT } +void IndexLowering::handle(const kir::SetMaxNReg* maxnreg) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(maxnreg)); // NOLINT +} + +void IndexLowering::handle(const kir::Return* ret) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(ret)); // NOLINT +} + void IndexLowering::handle(const kir::AsyncCommit* commit) { // TODO(kir): remove the need for const_cast pushBack(const_cast(commit)); // NOLINT diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 8d206159128..4cd7d7cdfdc 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::GridSync*) final; void handle(const kir::FenceAsyncProxy*) final; void handle(const kir::WgMmaFence*) final; + void handle(const kir::SetMaxNReg*) final; + void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; void handle(const kir::MBarrierInvalidate*) final; void handle(const kir::MBarrierArrive*) final; diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 31afc58a775..c27fd5294f6 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator { std::vector{}, kir::Asm::Options{/*volatile=*/true})); } + + void handle(kir::SetMaxNReg* maxnreg) final { + std::string ptx = (maxnreg->increaseRegisters()) + ? "setmaxnreg.inc.sync.aligned.u32" + : "setmaxnreg.dec.sync.aligned.u32"; + registerReplace( + maxnreg, + IrBuilder::create( + ptx, + std::vector{}, + std::vector{maxnreg->numberOfRegisters()}, + kir::Asm::Options{/*volatile=*/true})); + } }; std::vector lowerToInlinePtx(const std::vector& exprs) { diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 4fe0f86cc5f..7681aa878a1 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -120,6 +120,8 @@ class Val; f(GridSync); \ f(FenceAsyncProxy); \ f(WgMmaFence); \ + f(SetMaxNReg); \ + f(Return); \ f(MBarrierInit); \ f(MBarrierInvalidate); \ f(MBarrierArrive); \ @@ -146,6 +148,7 @@ class Val; f(HostUnit); \ f(PostOnStream); \ f(SetCurrentStream); \ + f(GetCurrentStream); \ f(Wait); \ f(Synchronize); \ f(StartCoalescing); \ diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index d825a2a941f..1b2554cdabb 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -6,13 +6,15 @@ */ // clang-format on +#include + #include #include #include +#include #include #include #include -#include #include #include #include @@ -34,14 +36,14 @@ bool HostIrExecutor::supported(Fusion* fusion) { FUSER_PERF_SCOPE("HostIrExecutor::supported"); std::vector exprs = fusion->exprs(); if (std::any_of(exprs.begin(), exprs.end(), [](Expr* e) { - return isResharding(e) && isLowerableToCommunication(e); + return isResharding(e) && HostIrLower::canLower(e); })) { NVF_ERROR( std::all_of( exprs.begin(), exprs.end(), [](Expr* e) { - return isResharding(e) && isLowerableToCommunication(e); + return isResharding(e) && HostIrLower::canLower(e); }), "Could not execute fusion as all expressions in a host IR container must be communication based at this point."); return true; @@ -67,8 +69,7 @@ void HostIrExecutor::compile(Fusion* fusion) { } 
else { std::vector exprs = fusion->exprs(); for (Expr* e : exprs) { - std::vector communications = - lowerCommunication(cloner.clone(e)); + std::vector communications = HostIrLower::lower(cloner.clone(e)); for (auto* communication : communications) { host_ir_container_->pushBackTopLevelExprs(communication); } @@ -187,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator( HostIrEvaluatorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params) { + params_(params), + my_device_index_(communicator_ ? communicator_->deviceId() : 0) { const DeviceIdxType device_index = (communicator_ != nullptr && communicator_->is_available()) ? communicator_->deviceId() @@ -218,6 +220,36 @@ std::vector HostIrEvaluator::runWithInput( return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_); } +std::string HostIrEvaluator::canRun() const { + const int64_t requested_n_gpus = requestedNumberOfDevices(container_.get()); + + if (requested_n_gpus == 1) { + return ""; + } + + if (communicator_ == nullptr) { + return "A communicator must be provided"; + } + + if (!communicator_->is_available()) { + return "distributed configuration required"; + } + + if (requested_n_gpus > communicator_->size()) { + return "the fusion requests " + std::to_string(requested_n_gpus) + + " GPUs to run, but there are only " + + std::to_string(communicator_->size()) + " ranks in the communicator"; + } + + if (communicator_->local_size() > at::cuda::getNumGPUs()) { + return std::to_string(communicator_->local_size()) + + " processes are spawn on the node but only " + + std::to_string(at::cuda::getNumGPUs()) + " GPUs are available"; + } + + return ""; +} + c10::cuda::CUDAStream HostIrEvaluator::getCUDAStream(Stream* stream) { StreamKey stream_key = stream; // if stream points to an index, it represents the dynamic value of that index @@ -242,8 +274,27 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { setCurrentCUDAStream(getCUDAStream(set_current_stream->stream())); } +void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) { + streams_.insert( + {get_current_stream->stream(), + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_))}); +} + void HostIrEvaluator::handle(Synchronize* synchronize) { - getCUDAStream(synchronize->stream()).synchronize(); + cudaStream_t current_stream = + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_)) + .stream(); + cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream(); + + cudaEvent_t event = {}; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); } void HostIrEvaluator::handle(PostOnStream* post_ir) { diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 76a27c2f5d1..a51dc32aed4 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -89,6 +89,10 @@ class HostIrEvaluator final : public OptOutDispatch { return container_->inputs(); } + const std::vector& outputs() { + return container_->outputs(); + } + std::ostream& print(std::ostream& os) const { return container_->print(os); }; @@ -101,9 +105,14 @@ class HostIrEvaluator final : public OptOutDispatch { return streams_; } + // check if the runtime is valid returns an error msg. 
+ // An empty message means that the runtime is valid + std::string canRun() const; + private: using OptOutDispatch::handle; void handle(SetCurrentStream* set_current_stream) override; + void handle(GetCurrentStream* get_current_stream) override; void handle(Synchronize* synchronize) override; void handle(PostOnStream* post_ir) override; void handle(Communication* communication) override; @@ -130,6 +139,7 @@ class HostIrEvaluator final : public OptOutDispatch { using StreamKey = std::variant; std::unordered_map streams_; std::unordered_map> works_; + const int64_t my_device_index_; }; } // namespace hir diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 492b2b22aab..49b33f59823 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const { return false; } +GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR(passkey.ir_container_->isA()); + auto stream = IrBuilder::createInContainer(passkey.ir_container_); + addAttribute(stream); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream) + +std::string GetCurrentStream::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString() + << std::endl; + return ss.str(); +} + Wait::Wait(IrBuilderPasskey passkey, Expr* expr) : Expr(passkey, {}, {}, {expr}) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index 587ffc43638..82d67d6f4cc 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -161,6 +161,28 @@ class SetCurrentStream : public Expr { } }; +class GetCurrentStream : public Expr { + public: + using Expr::Expr; + GetCurrentStream(IrBuilderPasskey passkey); + + GetCurrentStream(const GetCurrentStream& other) = delete; + GetCurrentStream& operator=(const GetCurrentStream& other) = delete; + GetCurrentStream(GetCurrentStream&& other) = delete; + GetCurrentStream& operator=(GetCurrentStream&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::GetCurrentStream"; + } + + Stream* stream() const { + return attributes_.at(0)->as(); + } +}; + class Wait : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/host_ir/lower.cpp similarity index 69% rename from csrc/multidevice/lower_communication.cpp rename to csrc/host_ir/lower.cpp index c9f410041da..8e97b958a9a 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/host_ir/lower.cpp @@ -6,13 +6,19 @@ */ // clang-format on #include +#include +#include #include #include #include #include -#include #include #include +#include +#include +#include +#include +#include #include namespace nvfuser { @@ -47,7 +53,7 @@ inline c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { void lowerToScatter( TensorView* input_tv, TensorView* output_tv, - std::vector& comms) { + std::vector& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); auto root = input_tv->getDeviceMesh().at(0); @@ -68,7 +74,7 @@ need multiple Gather if the tensor is replicated in the receiver mesh. 
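// Illustrative example: with a sender mesh of {0, 1} and a receiver mesh of
// {2, 3}, this creates two Gather communications, one rooted at device 2 and
// one rooted at device 3, each gathering the shards held by devices 0 and 1.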
void lowerToGather( TensorView* input_tv, TensorView* output_tv, - std::vector& comms) { + std::vector& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); for (auto root : output_tv->getDeviceMesh().vector()) { @@ -85,7 +91,7 @@ void lowerToGather( void lowerToAllgather( TensorView* input_tv, TensorView* output_tv, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); comms.push_back(IrBuilder::create( CommunicationType::Allgather, output_tv, input_tv, mesh.vector())); @@ -96,7 +102,7 @@ void lowerToBroadcast( TensorView* input_tv, TensorView* output_tv, DeviceIdxType root, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& mesh = output_tv->getDeviceMesh(); Team team = mesh.vector(); if (!mesh.has(root)) { @@ -113,7 +119,7 @@ void lowerToBroadcast( void lowerToBroadcastOrSendRecv( TensorView* input_tv, TensorView* output_tv, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); if (isSharded(input_tv) && sender_mesh.size() > 1) { @@ -154,7 +160,7 @@ void lowerToReduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const auto reduce_op_type = getC10dReduceOpType(op_type); @@ -178,7 +184,7 @@ void lowerToAllreduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); comms.push_back(IrBuilder::create( CommunicationType::Allreduce, @@ -193,7 +199,7 @@ void lowerToReduceScatter( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector& comms) { + std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); @@ -226,10 +232,10 @@ void lowerToReduceScatter( sources *) Leverage the topology to ensure that the senders and recerivers are close */ -std::vector lowerCommunication(Expr* c) { +std::vector HostIrLower::lower(Expr* c) { FusionGuard fg(c->fusion()); - std::vector comms; + std::vector comms; NVF_ERROR( c->inputs().size() == 1 && c->input(0)->isA() && c->outputs().size() == 1 && c->output(0)->isA(), @@ -251,7 +257,7 @@ std::vector lowerCommunication(Expr* c) { isSharded(output_tv) && receiver_mesh.size() > 1; NVF_ERROR( - isLowerableToCommunication(c), + HostIrLower::canLower(c), "Lowering expression ", c->toString(), " to communication is not supported"); @@ -296,7 +302,10 @@ std::vector lowerCommunication(Expr* c) { return comms; } -bool isLowerableToCommunication(Expr* expr) { +bool HostIrLower::canLower(Expr* expr) { + if (!isResharding(expr)) { + return true; + } if (!ir_utils::isTvOp(expr)) { return false; } @@ -325,4 +334,100 @@ bool isLowerableToCommunication(Expr* expr) { } } +std::unique_ptr HostIrLower::lower( + std::unique_ptr fusion, + int64_t my_device_index) { + // Sharding PreSegmenter passes. + // Note: passes run before PreSegmenter optimization passes. 
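  // Overview (illustrative): the sharding passes below run first; lower() then
  // segments the fusion at resharding expressions and builds a HostIrContainer
  // whose top-level expressions are either PostOnStream nodes (compute
  // segments) or Allocate / Communication / Wait sequences (communication
  // segments).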
+ preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::InsertReshardingsPass>::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::MakeReshardingContiguousPass>::runPass(fusion.get()); + + // Performs segmentation at the inter-device communications + // Each SegmentedGroup represents a pipeline's stage, and can be either + // 1) a Fusion which doesn't involve inter-device communication + // 2) a Fusion comprised of one Expr, representing inter-device communication + SegmentCandidateFinderOptions options{ + .run_translate_welford = false, + .run_combine_reductions = false, + .run_herrmann_merge = true, + .run_final_merge = true, + .only_segment_resharding_exprs = true}; + std::unique_ptr staged_fusion = + SegmentCandidateFinder::segment(std::move(fusion), nullptr, options); + // Infer a topologically ordered traversal of the segmented fusion to + // determine the order for launching the kernels/comms + RuntimeWorkSpace workspace; + prepareRuntimeOrder(staged_fusion.get(), workspace); + + // Create the HostIrContainer representing the host program. Each segment of + // the segmented fusion will be translated to a HostIR + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + IrCloner ir_cloner(hic.get()); + auto clone = + [&ir_cloner](const std::vector& vals) -> std::vector { + std::vector cloned_vals(vals.size()); + std::transform( + vals.begin(), vals.end(), cloned_vals.begin(), [&ir_cloner](Val* val) { + return ir_cloner.clone(val); + }); + return cloned_vals; + }; + + for (auto group : workspace.group_run_order) { + std::vector host_exprs; + NVF_ERROR(!group->exprs().empty(), "invalid segmentation"); + if (involvedDevices(group->exprs().at(0)).count(my_device_index) == 0) { + continue; + } + const bool is_resharding = std::any_of( + group->exprs().begin(), group->exprs().end(), [](auto expr) { + return isResharding(expr); + }); + if (is_resharding) { + NVF_ERROR( + group->exprs().size() == 1, + "Communication segments must contain only one Expr"); + for (auto* expr : + HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) { + // Allocate the recv buffers of communications + NVF_ERROR( + expr->isA(), + "Expected a Communication but got ", + expr); + auto* communication = expr->as(); + TensorView* tv = communication->out(); + if (tv->getDeviceMesh().has(my_device_index)) { + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + hic->pushBackTopLevelExprs(allocate); + } + hic->pushBackTopLevelExprs(communication); + auto wait = IrBuilder::create(communication); + hic->pushBackTopLevelExprs(wait); + } + } else { + auto host_unit = IrBuilder::create( + staged_fusion->makeFusion(group).second); + auto post_on_stream = IrBuilder::create( + host_unit, clone(group->inputs()), clone(group->outputs())); + hic->pushBackTopLevelExprs(post_on_stream); + } + } + for (auto input : staged_fusion->inputs()) { + hic->addInput(ir_cloner.clone(input)); + } + for (auto output : staged_fusion->outputs()) { + hic->addOutput(ir_cloner.clone(output)); + } + + return hic; +} + } // namespace nvfuser diff --git a/csrc/multidevice/lower_communication.h b/csrc/host_ir/lower.h similarity index 52% rename from csrc/multidevice/lower_communication.h rename to csrc/host_ir/lower.h index fe4c853373f..6a1d44247d2 100644 --- a/csrc/multidevice/lower_communication.h +++ 
b/csrc/host_ir/lower.h @@ -7,17 +7,23 @@ // clang-format on #pragma once +#include #include #include #include namespace nvfuser { -// Returns whether we support transforming a given expression into a series -// of communication. -bool isLowerableToCommunication(Expr* expr); +class HostIrLower { + public: + static bool canLower(Expr* expr); -// Lower a PipelineCommunication into a series of Communication. -std::vector lowerCommunication(Expr* c); + // Lower a sharded Expr into a series of Communication. + static std::vector lower(Expr* c); + + static std::unique_ptr lower( + std::unique_ptr fusion, + int64_t my_device_index); +}; } // namespace nvfuser diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index c2a6127a861..76f59cecb73 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -44,6 +44,26 @@ IndexingTraversal::IndexingTraversal( } resize_paths_.insert(resize); } + + // A unique expr path should be always allowed + for (const auto& expr_g : graph.disjointExprSets().disjointSets()) { + auto resize = dynamic_cast(expr_g->front()); + if (resize == nullptr) { + continue; + } + + auto input_groups = graph.inputGroups(expr_g); + auto output_groups = graph.outputGroups(expr_g); + NVF_ERROR(input_groups.size() == 1); + NVF_ERROR(output_groups.size() == 1); + + if (graph.getUses(input_groups[0]).size() != 1 || + graph.getDefinitions(output_groups[0]).size() != 1) { + continue; + } + + resize_paths_.insert(resize); + } } std::optional IndexingTraversal:: @@ -65,18 +85,26 @@ std::optional IndexingTraversal:: /*build_graphs=*/false); // Gather all resize exprs for each of the inputs and outputs - std::unordered_map> tv_resize_map; - for (auto inp : ir_utils::filterByType(expr->inputs())) { - for (auto expr : inp->domain()->allExprs()) { + std::unordered_map> tv_resize_map; + for (auto inp : expr->inputs()) { + auto inp_tv = ir_utils::getTv(inp); + if (inp_tv == nullptr) { + continue; + } + for (auto expr : inp_tv->domain()->allExprs()) { if (auto resize = dynamic_cast(expr)) { - tv_resize_map[inp].push_back(resize); + tv_resize_map[inp_tv].push_back(resize); } } } - for (auto out : ir_utils::filterByType(expr->outputs())) { - for (auto expr : out->domain()->allExprs()) { + for (auto out : expr->outputs()) { + auto out_tv = ir_utils::getTv(out); + if (out_tv == nullptr) { + continue; + } + for (auto expr : out_tv->domain()->allExprs()) { if (auto resize = dynamic_cast(expr)) { - tv_resize_map[out].push_back(resize); + tv_resize_map[out_tv].push_back(resize); } } } @@ -149,9 +177,17 @@ std::optional IndexingTraversal:: }; bool single_id_resized_multiple_times = false; - for (auto out : ir_utils::filterByType(expr->outputs())) { - for (auto inp : ir_utils::filterByType(expr->inputs())) { - if (isSingleIdResizedMultipleTimes(inp, out)) { + for (auto out : expr->outputs()) { + auto out_tv = ir_utils::getTv(out); + if (out_tv == nullptr) { + continue; + } + for (auto inp : expr->inputs()) { + auto inp_tv = ir_utils::getTv(inp); + if (inp_tv == nullptr) { + continue; + } + if (isSingleIdResizedMultipleTimes(inp_tv, out_tv)) { single_id_resized_multiple_times = true; break; } diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 7352e7fc96e..df9b0bf50c9 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -320,6 +320,8 @@ class NVF_API UnaryOp : public Expr { return "UnaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const 
std::vector& inputs) const override; @@ -358,6 +360,8 @@ class NVF_API BinaryOp : public Expr { return "BinaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const override; @@ -405,6 +409,8 @@ class TernaryOp : public Expr { return "TernaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const override; @@ -1445,15 +1451,15 @@ class NVF_API MmaOp : public Expr { return attribute(ATTR_POS_MACRO); } - int m() const { + int64_t m() const { return getM(macro()); } - int n() const { + int64_t n() const { return getN(macro()); } - int k() const { + int64_t k() const { return getK(macro()); } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 423035367ae..5f0528991b2 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -557,6 +557,12 @@ std::string UnaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string UnaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getUnaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(UnaryOp) BinaryOp::BinaryOp( @@ -724,6 +730,12 @@ std::string BinaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string BinaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getBinaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(BinaryOp) TernaryOp::TernaryOp( @@ -825,6 +837,12 @@ std::string TernaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string TernaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getTernaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(TernaryOp) ArrayConstruct::ArrayConstruct( @@ -1250,7 +1268,17 @@ BroadcastOp::BroadcastOp( std::string BroadcastOp::toString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << out()->toString() << "\n"; - indent(ss, indent_size) << " = broadcast( " << in()->toString() << " )\n"; + indent(ss, indent_size) << " = broadcast( " << in()->toString() + << ", flags = {"; + bool is_first = true; + for (const auto f : getBroadcastDimFlags()) { + if (!is_first) { + ss << ", "; + } + ss << (f ? "true" : "false"); + is_first = false; + } + ss << "} )\n"; return ss.str(); } @@ -2966,6 +2994,59 @@ void validateContiguity( } } +// Check if loop_domain is a valid domain with no +// redundancy. The logical domain is used as a reference to find if +// there's any ID that's not covered by the new loop domain. 
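// For example (illustrative): with logical = {i0, i1}, a loop domain of {i0}
// is rejected because i1 is unreachable from it, and a loop domain containing
// both i0 and merge(i0, i1) is rejected as redundant; either case is tolerated
// when the offending IDs are broadcasts.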
+void validateLoopDomain( + const std::vector& logical_domain, + const std::vector& loop_domain, + const std::vector& additional_ids) { + // Skip if there's any symbolic ID + if (std::any_of( + logical_domain.begin(), + logical_domain.end(), + [](IterDomain* id) { return id->isSymbolic(); }) || + std::any_of( + loop_domain.begin(), + loop_domain.end(), + [](IterDomain* id) { return id->isSymbolic(); }) || + std::any_of( + additional_ids.begin(), additional_ids.end(), [](IterDomain* id) { + return id->isSymbolic(); + })) { + return; + } + + std::vector reference; + reference.reserve(logical_domain.size() + additional_ids.size()); + reference.insert( + reference.end(), logical_domain.begin(), logical_domain.end()); + // additional_ids are also considered part of the refernece domain + reference.insert( + reference.end(), additional_ids.begin(), additional_ids.end()); + + auto [redundant_ids, _, unreachable_reference_ids] = + ir_utils::compareDomainWithReference(loop_domain, reference); + + auto empty_or_broadcast = [](const auto& ids) { + return std::all_of(ids.begin(), ids.end(), [](IterDomain* id) { + return id->isBroadcast(); + }); + }; + + NVF_ERROR( + empty_or_broadcast(redundant_ids), + "Trying to set a loop domain with non-broadcast redundant IDs: ", + toDelimitedString(redundant_ids)); + + NVF_ERROR( + empty_or_broadcast(unreachable_reference_ids), + "Not all logical IDs are covered by loop domain. Loop: ", + toDelimitedString(loop_domain), + ". Unreachable logical IDs: ", + toDelimitedString(unreachable_reference_ids)); +} + } // namespace TensorDomain::TensorDomain( @@ -3036,8 +3117,7 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - ir_utils::validateDomainEquivalence( - logical_domain_, loop_domain_, additional_ids_); + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); // resetDomains initializes other member variables, required by clang-tidy resetDomains(); @@ -3061,8 +3141,7 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - ir_utils::validateDomainEquivalence( - logical_domain_, loop_domain_, additional_ids_); + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); if (!root_domain_.empty()) { ir_utils::validateDomainEquivalence( logical_domain_, root_domain_, additional_ids_); @@ -3095,8 +3174,7 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - ir_utils::validateDomainEquivalence( - logical_domain_, loop_domain_, additional_ids_); + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); if (!root_domain_.empty()) { ir_utils::validateDomainEquivalence( logical_domain_, root_domain_, additional_ids_); @@ -3670,33 +3748,7 @@ std::pair TensorDomain::rFactor( } void TensorDomain::setLoopDomain(std::vector new_loop_domain) { - // Check if new_loop_domain is a valid domain with no - // redundancy. The logical domain is used as a reference to find if - // there's any ID that's not covered by the new loop domain. 
- std::vector reference; - reference.reserve(logical_domain_.size() + additional_ids_.size()); - reference.insert( - reference.end(), logical_domain_.begin(), logical_domain_.end()); - // additional_ids_ are also considered part of the refernece domain - reference.insert( - reference.end(), additional_ids_.begin(), additional_ids_.end()); - auto [redundant_ids, additional_ids, unreachable_reference_ids] = - ir_utils::compareDomainWithReference(new_loop_domain, reference); - NVF_ERROR( - redundant_ids.empty(), - "Trying to set a loop domain with redundant IDs: ", - toDelimitedString(redundant_ids)); - if (!unreachable_reference_ids.empty()) { - NVF_ERROR( - std::all_of( - unreachable_reference_ids.begin(), - unreachable_reference_ids.end(), - [](const auto id) { return id->isBroadcast(); }), - "Not all logical IDs are covered by loop domain. Loop: ", - toDelimitedString(new_loop_domain), - ". Unreachable logical IDs: ", - toDelimitedString(unreachable_reference_ids)); - } + validateLoopDomain(logical(), new_loop_domain, additionalIDs()); loop_domain_ = std::move(new_loop_domain); initial_loop_domain_ = loop_domain_; resetDomains(); @@ -3716,10 +3768,10 @@ void TensorDomain::setAllocationDomain( std::vector TensorDomain::allIDs() const { std::array*, 6> all_domains = { + &loop_domain_, &logical_domain_, &root_domain_, &initial_loop_domain_, - &loop_domain_, &allocation_domain_, &additional_ids_}; VectorOfUniqueEntries discovered_ids; @@ -4375,6 +4427,39 @@ std::vector CatOp::evaluate( return {at::cat(unpadded_inputs, concat_dim)}; } +namespace { + +// Given a tensorview, compute the strides according to the allocation domain +// for re-striding the corresponding ATen tensor. +std::vector computeStrides( + TensorView* tv, + const c10::IntArrayRef sizes) { + const auto& logical_domain = tv->getLogicalDomain(); + const auto& allocation_domain = tv->getMaybeAllocationDomain(); + + std::optional> out_order = ir_utils::computePermutation( + TensorDomain::noReductions(logical_domain), + TensorDomain::noReductions(allocation_domain)); + NVF_CHECK( + out_order.has_value(), + "Valid permute from logical to allocation domain was not found."); + + auto rank = sizes.size(); + std::vector sorted_strides(rank); + auto permuted_sizes = ir_utils::applyPermutation(sizes.vec(), *out_order); + sorted_strides[rank - 1] = 1; + for (int64_t idx = (int64_t)rank - 2; idx >= 0; idx--) { + sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1]; + } + // Rearrange the strides in correct order of allocation + std::vector strides(rank); + for (auto idx : c10::irange(rank)) { + strides[out_order.value()[idx]] = sorted_strides[idx]; + } + return strides; +} +} // namespace + MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b) : Expr(passkey) { addOutput(out); @@ -4401,7 +4486,17 @@ std::vector MatmulOp::evaluate( const std::vector& inputs) const { const auto a = inputs.at(0).as(); const auto b = inputs.at(1).as(); - return {at::matmul(a, b)}; + + auto matmul_out = at::matmul(a, b); + if (ir_utils::hasTrivialAllocationDomain(out())) { + return {matmul_out}; + } + auto matmul_sizes = matmul_out.sizes(); + auto strides = computeStrides(out(), matmul_sizes); + auto strided_matmul_out = + at::empty_strided(matmul_sizes, strides, a.options()); + strided_matmul_out = strided_matmul_out.copy_(matmul_out); + return {strided_matmul_out}; } LinearOp::LinearOp( diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 7eadac6abaa..4de5d7c8097 100644 --- a/csrc/ir/utils.cpp +++ 
b/csrc/ir/utils.cpp @@ -938,14 +938,13 @@ CompareDomainWithReferenceResult compareDomainWithReference( // the reference domain. If it's connected even with missing // dependencies, it should be considered redundant. For this // reason, a variant of IRBFS with a relaxed dependency condition - // is used. IRBFSWithPermissiveDependence can traverse as long as one of + // is used. IRPermissiveBFS can traverse as long as one of // the inputs or outputs is visited. - const auto from_remaining_ids = - getExprsBetween( - {unused_ids.begin(), unused_ids.end()}, - {reference.begin(), reference.end()}, - /*require_all_to_visited=*/false) - .first; + const auto from_remaining_ids = getExprsBetween( + {unused_ids.begin(), unused_ids.end()}, + {reference.begin(), reference.end()}, + /*require_all_to_visited=*/false) + .first; // Nothing is reachable, which means all of the unused IDs are not redundant if (from_remaining_ids.empty()) { additional_ids = unused_ids; diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index a174b47a2c6..f27c89068f6 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -634,55 +634,28 @@ inline std::vector getOutputsOfExpr(Expr* expr, Direction dir) { return getOutputsOfExpr(expr, dir, IRInputs(), IROutputs()); } -// Unlike the default IRBFS behavior, Expr is considered ready to -// visit as long as one of the inputs or outputs has its dependency met -class IRBFSWithPermissiveDependence : public IRBFS { +class IRPermissiveBFS : public BFSWithPermissiveDependence< + Expr*, + Val*, + IRDefinitions, + IRUses, + IRInputs, + IROutputs> { public: - IRBFSWithPermissiveDependence( - const std::vector& from_ids, - const std::vector& to_ids, - bool require_all_to_visited = true, + IRPermissiveBFS( + std::vector from_groups, + std::vector to_groups, + bool require_all_to_visited, Direction allowed_direction = Direction::Undefined) - : IRBFS( - {from_ids.begin(), from_ids.end()}, - {to_ids.begin(), to_ids.end()}, + : BFSWithPermissiveDependence( + IRDefinitions{}, + IRUses{}, + IRInputs{}, + IROutputs{}, + std::move(from_groups), + std::move(to_groups), require_all_to_visited, allowed_direction) {} - - std::optional>> isReady( - const ExprType& expr) const override { - // Either any inputs or any outputs must have been visited - decltype(auto) inputs = inputs_(expr); - if (!inputs.empty() && allowed_direction_ != Direction::Backward && - std::any_of( - inputs.begin(), inputs.end(), [&](const ValType& input) -> bool { - return isDependencySatisfied(input); - })) { - std::vector prev_nodes; - std::copy_if( - inputs.begin(), - inputs.end(), - std::back_inserter(prev_nodes), - [&](const ValType& input) -> bool { return isVisited(input); }); - return std::make_pair(Direction::Forward, prev_nodes); - } - - decltype(auto) outputs = outputs_(expr); - if (!outputs.empty() && allowed_direction_ != Direction::Forward && - std::any_of( - outputs.begin(), outputs.end(), [&](const ValType& output) -> bool { - return isDependencySatisfied(output); - })) { - std::vector prev_nodes; - std::copy_if( - outputs.begin(), - outputs.end(), - std::back_inserter(prev_nodes), - [&](const ValType& output) -> bool { return isVisited(output); }); - return std::make_pair(Direction::Backward, prev_nodes); - } - return std::nullopt; - } }; } // namespace nvfuser diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index fc464eac315..a1549ba28f6 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -485,6 +485,47 @@ std::string WgMmaFence::toInlineString(int indent_size) const { 
NVFUSER_DEFINE_CLONE_AND_CREATE(WgMmaFence) +SetMaxNReg::SetMaxNReg( + IrBuilderPasskey passkey, + Val* number_of_registers, + bool increase_registers) + : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + addInput(number_of_registers); + addDataAttribute(increase_registers); +} + +std::string SetMaxNReg::toString(int indent_size) const { + return (increaseRegisters()) ? "setmaxnreg.inc.sync.aligned.u32" + : "setmaxnreg.dec.sync.aligned.u32"; +} + +std::string SetMaxNReg::toInlineString(int indent_size) const { + NVF_CHECK(false, "SetMaxNReg can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(SetMaxNReg) + +Return::Return(IrBuilderPasskey passkey) : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +std::string Return::toString(int indent_size) const { + return "return"; +} + +std::string Return::toInlineString(int indent_size) const { + NVF_CHECK(false, "Return can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Return) + MBarrierInit::MBarrierInit( IrBuilderPasskey passkey, Val* mbarrier, diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 60421db1995..e8a68bd8eb3 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -41,6 +41,8 @@ class BlockSync; class GridSync; class FenceAsyncProxy; class WgMmaFence; +class SetMaxNReg; +class Return; class MBarrierInit; class MBarrierInvalidate; class MBarrierArrive; @@ -469,6 +471,50 @@ class WgMmaFence final : public Expr { std::string toInlineString(int indent_size = 0) const override; }; +// PTX: setmaxnreg.inc.sync.aligned.u32 and setmaxnreg.dec.sync.aligned.u32 +class SetMaxNReg final : public Expr { + public: + using Expr::Expr; + + explicit SetMaxNReg( + IrBuilderPasskey passkey, + Val* number_of_registers, + bool increase_registers); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return (increaseRegisters()) ? 
"IncSetMaxNReg" : "DecSetMaxNReg"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + bool increaseRegisters() const { + return attribute(0); + } + + Val* numberOfRegisters() const { + return input(0); + } +}; + +class Return final : public Expr { + public: + using Expr::Expr; + + explicit Return(IrBuilderPasskey passkey); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Return"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; +}; + class MBarrierInit final : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 8197ea224f4..6cf1a499bb9 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include @@ -196,6 +197,8 @@ Communicator::Communicator( return; } + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank_)); + #ifdef NVFUSER_DISTRIBUTED c10d::TCPStoreOptions store_opts; { diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 157dd6b99dc..963b80812d3 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -5,23 +5,16 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include - #include -#include #include #include +#include #include #include #include #include #include -#include #include -#include -#include -#include -#include #include #include @@ -31,94 +24,9 @@ MultiDeviceExecutor::MultiDeviceExecutor( std::unique_ptr fusion, Communicator& comm, hir::HostIrEvaluatorParams params) - : comm_(comm), complete_fusion_(std::move(fusion)) { - // Sharding PreSegmenter passes. - // Note: passes run before PreSegmenter optimization passes. - preseg_passes::OptimizationPass< - preseg_passes::PropagateShardingsPass>::runPass(complete_fusion_.get()); - preseg_passes::OptimizationPass< - preseg_passes::InsertReshardingsPass>::runPass(complete_fusion_.get()); - preseg_passes::OptimizationPass< - preseg_passes::ReorderShardedAxisPass>::runPass(complete_fusion_.get()); - preseg_passes::OptimizationPass:: - runPass(complete_fusion_.get()); - - // Performs segmentation at the inter-device communications - // Each SegmentedGroup represents a pipeline's stage, and can be either - // 1) a Fusion which doesn't involve inter-device communication - // 2) a Fusion comprised of one Expr, representing inter-device communication - SegmentCandidateFinderOptions options{ - .run_translate_welford = false, - .run_combine_reductions = false, - .run_herrmann_merge = true, - .run_final_merge = true, - .only_segment_resharding_exprs = true}; - std::unique_ptr staged_fusion = - SegmentCandidateFinder::segment( - std::make_unique(*complete_fusion_), nullptr, options); - // Infer a topologically ordered traversal of the segmented fusion to - // determine the order for launching the kernels/comms - RuntimeWorkSpace workspace; - prepareRuntimeOrder(staged_fusion.get(), workspace); - - // Create the HostIrContainer representing the host program. 
Each segment of - // the segmented fusion will be translated to a HostIR - auto hic = std::make_unique(); - FusionGuard fg(hic.get()); - IrCloner ir_cloner(hic.get()); - auto clone = - [&ir_cloner](const std::vector& vals) -> std::vector { - std::vector cloned_vals(vals.size()); - std::transform( - vals.begin(), vals.end(), cloned_vals.begin(), [&ir_cloner](Val* val) { - return ir_cloner.clone(val); - }); - return cloned_vals; - }; - - for (auto group : workspace.group_run_order) { - std::vector host_exprs; - NVF_ERROR(!group->exprs().empty(), "invalid segmentation"); - if (involvedDevices(group->exprs().at(0)).count(comm_.deviceId()) == 0) { - continue; - } - const bool is_resharding = std::any_of( - group->exprs().begin(), group->exprs().end(), [](auto expr) { - return isResharding(expr); - }); - if (is_resharding) { - NVF_ERROR( - group->exprs().size() == 1, - "Communication segments must contain only one Expr"); - std::vector communications = - lowerCommunication(ir_cloner.clone(group->exprs().at(0))); - for (Communication* communication : communications) { - // Allocate the recv buffers of communications - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(comm_.deviceId())) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - hic->pushBackTopLevelExprs(allocate); - } - hic->pushBackTopLevelExprs(communication); - auto wait = IrBuilder::create(communication); - hic->pushBackTopLevelExprs(wait); - } - } else { - auto host_unit = IrBuilder::create( - staged_fusion->makeFusion(group).second); - auto post_on_stream = IrBuilder::create( - host_unit, clone(group->inputs()), clone(group->outputs())); - hic->pushBackTopLevelExprs(post_on_stream); - } - } - for (auto input : staged_fusion->inputs()) { - hic->addInput(ir_cloner.clone(input)); - } - for (auto output : staged_fusion->outputs()) { - hic->addOutput(ir_cloner.clone(output)); - } - + : comm_(comm) { + std::unique_ptr hic = + HostIrLower::lower(std::move(fusion), comm.deviceId()); // Create the HostIrEvaluator representing the host program host_ir_executor_ = std::make_unique(std::move(hic), &comm, params); @@ -147,27 +55,6 @@ std::vector MultiDeviceExecutor::runWithInput( return host_ir_executor_->runWithInput(val_to_IValue); } -std::string MultiDeviceExecutor::validate() const { - if (!comm_.is_available()) { - return "distributed configuration required"; - } - - if (requestedNumberOfDevices(completeFusion()) > comm_.size()) { - return "the pipeline requests " + - std::to_string(requestedNumberOfDevices(completeFusion())) + - " GPUs to run, but there are only " + std::to_string(comm_.size()) + - " ranks in the communicator"; - } - - if (comm_.size() > at::cuda::getNumGPUs()) { - return std::to_string(comm_.local_size()) + - " processes are spawn on the node but only " + - std::to_string(at::cuda::getNumGPUs()) + " GPUs are available"; - } - - return ""; -} - std::ostream& MultiDeviceExecutor::print(std::ostream& os) { return host_ir_executor_->print(os); } diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index 4ff8065099c..7cad0388b18 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -84,14 +84,11 @@ class MultiDeviceExecutor { return &comm_; } - // Returns the Fusion - auto completeFusion() const { - return complete_fusion_.get(); - } - // check if the runtime is valid returns an error msg. // An empty message means that the runtime is valid - std::string validate() const; + std::string validate() const { + return host_ir_executor_->canRun(); + } //! 
Print to default debugging output stream std::ostream& print(std::ostream& os = debug()); @@ -103,8 +100,6 @@ class MultiDeviceExecutor { private: // holds the Communicator to be used for execution Communicator& comm_; - // holds the original complete fusion - std::unique_ptr complete_fusion_; // holds the HostIrEvaluator used for execution std::unique_ptr host_ir_executor_; }; diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 020f746f011..847557bfa3a 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -7,13 +7,13 @@ // clang-format on #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -565,12 +565,12 @@ std::set involvedDevices(Expr* expr) { {ir_utils::filterByType(expr->inputs()), ir_utils::filterByType(expr->outputs())}) { for (auto* tv : tvs) { - NVF_ERROR( - tv->hasDeviceMesh(), - "the TensorView has no device mesh: ", - tv->toString()); - auto& mesh = tv->getDeviceMesh().vector(); - std::copy(mesh.begin(), mesh.end(), std::inserter(ret, ret.end())); + if (tv->hasDeviceMesh()) { + auto& mesh = tv->getDeviceMesh().vector(); + std::copy(mesh.begin(), mesh.end(), std::inserter(ret, ret.end())); + } else { + ret.insert(0); + } } } return ret; diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 1db959115a2..d2f0d9277d2 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -56,6 +56,59 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) { return dx; } +TensorView* triu(TensorView* tv, Val* offset) { + NVF_CHECK( + isIntegralType(offset->getDataType().value()), + "offset must have integral type"); + + // Let's say we want a triu of a 2D tensor of shape [2, 4] + // We broadcast the iota of the outer dim + // [0 [0, 0, 0, 0] + // 1] -> [1, 1, 1, 1] + // We broadcast the iota of the inner dim + // [0, 1, 2, 3] -> [0, 1, 2, 3] + // [0, 1, 2, 3] + // Using LE on the bcast tensors we get the mask + //[0, 0, 0, 0] LE [0, 1, 2, 3] + //[1, 1, 1, 1] [0, 1, 2, 3] + // Gives: + //[1, 1, 1, 1] + //[0, 1, 1, 1] + auto tv_logical_no_reductions = + TensorDomain::noReductions(tv->getLogicalDomain()); + auto dims = tv_logical_no_reductions.size(); + + NVF_CHECK( + dims >= 2, + "triu is only supported for 2+D tensors, but got ", + dims, + "D tensor"); + + auto fusion = tv->fusion(); + + auto tv_rows = iota( + tv_logical_no_reductions[dims - 2]->extent(), + fusion->zeroVal(DataType::Index), + fusion->oneVal(DataType::Index), + DataType::Index); + + // If triu has an offset of k, we shift/subtract the iota of the columns by k + // before broadcasting and comparing with the iota of the rows. + // So when building an iota op, instead of starting from 0 with a step of 1 + // we start from -offset (== -k) with a step of 1. 
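  // Illustrative example with a non-zero offset: for a 2x4 input and k = 1,
  // the column iota becomes [-1, 0, 1, 2], so the LE comparison yields the mask
  // [0, 1, 1, 1]
  // [0, 0, 1, 1]
  // i.e. the main diagonal and everything below it are zeroed out.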
+ auto start_shifted_by_offset = SimplifyingIrBuilder::negExpr(offset); + auto tv_columns = iota( + tv_logical_no_reductions[dims - 1]->extent(), + start_shifted_by_offset, + fusion->oneVal(DataType::Index), + DataType::Index); + + auto tv_rows_b = broadcast(tv_rows, {false, true}); + auto tv_cols_b = broadcast(tv_columns, {true, false}); + auto mask = le(tv_rows_b, tv_cols_b); + return where(mask, tv, fusion->zeroVal(DataType::Index)); +} + namespace { TensorView* newForLinear( diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index ecbbb89b5a3..b67015b994d 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -35,6 +35,8 @@ NVF_API TensorView* dropout_backward( TensorView* mask, Val* scale); +NVF_API TensorView* triu(TensorView* tv, Val* offset); + struct LstmResult { TensorView* cell = nullptr; TensorView* hidden = nullptr; diff --git a/csrc/options.cpp b/csrc/options.cpp index 639e0c57622..0ca177aff7d 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -112,12 +112,13 @@ std::unordered_map> Options< {"expr_sort_verbose", DebugDumpOption::ExprSortVerbose}, {"ftrace", DebugDumpOption::FunctionTrace}, {"fusion_args", DebugDumpOption::FusionArgs}, - {"fusion_ir_original", DebugDumpOption::FusionIrOriginal}, - {"fusion_ir_concretized", DebugDumpOption::FusionIrConcretized}, - {"fusion_ir_preseg", DebugDumpOption::FusionIrPreseg}, - {"fusion_ir_presched", DebugDumpOption::FusionIrPresched}, {"fusion_ir", DebugDumpOption::FusionIr}, + {"fusion_ir_concretized", DebugDumpOption::FusionIrConcretized}, + {"fusion_ir_graph", DebugDumpOption::FusionIrGraph}, {"fusion_ir_math", DebugDumpOption::FusionIrMath}, + {"fusion_ir_original", DebugDumpOption::FusionIrOriginal}, + {"fusion_ir_presched", DebugDumpOption::FusionIrPresched}, + {"fusion_ir_preseg", DebugDumpOption::FusionIrPreseg}, {"global_zeroed_memory", DebugDumpOption::GlobalZeroedMemory}, {"host_ir", DebugDumpOption::HostIr}, {"index_type", DebugDumpOption::IndexType}, @@ -154,15 +155,16 @@ const std::unordered_map& getEnableOptions() { {"fuse_matmul", EnableOption::FuseMatmul}, {"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls}, {"id_model", EnableOption::IdModel}, + {"io_to_lower_precision", EnableOption::IoToLowerPrecision}, {"kernel_db", EnableOption::KernelDb}, + {"kernel_debug", EnableOption::KernelDebug}, + {"kernel_lineinfo", EnableOption::KernelLineInfo}, {"kernel_profile", EnableOption::KernelProfile}, {"memory_promotion", EnableOption::MemoryPromotion}, {"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory}, + {"resize_scheduler", EnableOption::ResizeScheduler}, {"static_fusion_count", EnableOption::StaticFusionCount}, {"warn_register_spill", EnableOption::WarnRegisterSpill}, - {"io_to_lower_precision", EnableOption::IoToLowerPrecision}, - {"kernel_debug", EnableOption::KernelDebug}, - {"kernel_lineinfo", EnableOption::KernelLineInfo}, }; return available_options; } diff --git a/csrc/options.h b/csrc/options.h index 8d69719897c..6e313672a02 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -40,6 +40,7 @@ enum class DebugDumpOption { // TODO(wujingyue): name the following FusionIrSched FusionIr, //!< Dump the Fusion IR before lowering. This is the Fusion IR fed //!< to `KernelExecutor::compileFusion`. + FusionIrGraph, //!< Dump a GraphViz graph of the Fusion IR FusionIrMath, //!< Dump just the compute (math) part of the above `FusionIr` //!< for conciseness KernelIr, //!< Dump the compiler Kernel IR @@ -93,17 +94,18 @@ enum class EnableOption { FuseMatmul, //! 
Enable automatic fusion of matmul and linear ops FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel IdModel, //! Enable IdModel - KernelDb, //! Enable Kernel Database - KernelProfile, //! Enable intra-kernel performance profiling - MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops - StaticFusionCount, //! Enable using single static count in kernel name - ReuseZeroedMemory, //! Re-use zeroed memory used for grid synchronization - WarnRegisterSpill, //! Enable warnings of register spill IoToLowerPrecision, //! Enable castInputOutputToLowerPrecision. #1889 explains //! why we disabled it by default. + KernelDb, //! Enable Kernel Database KernelDebug, //! Enable debug mode in nvrtc KernelLineInfo, //! Embed line info to compiled kernel, and dump the full CUDA //! C++ code + KernelProfile, //! Enable intra-kernel performance profiling + MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops + ReuseZeroedMemory, //! Re-use zeroed memory used for grid synchronization + ResizeScheduler, //! Enable the resize scheduler + StaticFusionCount, //! Enable using single static count in kernel name + WarnRegisterSpill, //! Enable warnings of register spill EndOfOption //! Placeholder for counting the number of elements }; diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp index e2c838ef6e8..ca929d4e982 100644 --- a/csrc/polymorphic_value.cpp +++ b/csrc/polymorphic_value.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include @@ -44,10 +45,7 @@ namespace PolymorphicValue_functions { std::string toString(const PolymorphicValue& v) { std::stringstream ss; if (v.is()) { - const auto& t = v.as(); - ss << "Tensor(sizes=" << t.sizes() << ", " - << "stride=" << t.strides() << ", dtype=" << t.dtype() - << ", device=" << t.device() << ", data_ptr=" << t.data_ptr() << ")"; + ss << debug_str(v.as()); } else if (v.is()) { ss << "std::monostate"; } else if (v.is()) { diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index 047095f79d8..9d62e0dc1a9 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -9,11 +9,11 @@ #include #include +#include #include #include #include #include -#include #include #include @@ -33,7 +33,7 @@ void insertReshardingsBefore(Fusion* fusion) { // Remove this after we refactor this as a pre-segmenter pass. FusionGuard fg(fusion); for (Expr* expr : fusion->exprs()) { - if (isLowerableToCommunication(expr) || shouldReshardAfter(expr)) { + if (HostIrLower::canLower(expr) || shouldReshardAfter(expr)) { continue; } @@ -85,7 +85,7 @@ void insertReshardingsAfter(Fusion* fusion) { auto exprs = fusion->exprs(); for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) { Expr* expr = *it; - if (isLowerableToCommunication(expr) || !shouldReshardAfter(expr)) { + if (HostIrLower::canLower(expr) || !shouldReshardAfter(expr)) { continue; } diff --git a/csrc/preseg_passes/move_pad.cpp b/csrc/preseg_passes/move_pad.cpp index e2773ac6ea2..3163a4613bd 100644 --- a/csrc/preseg_passes/move_pad.cpp +++ b/csrc/preseg_passes/move_pad.cpp @@ -305,47 +305,48 @@ TensorView* replayConcretePad( // returns padded inputs. When moving pad fails, this function returns an empty // vector. 
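+// For example, given
+//   t1 = neg(t0); t2 = pad(t1);
+// where the pad is t1's only use, this helper replays the pad on the inputs
+// of t1's definition and returns {pad(t0)}, so the caller can rebuild the
+// chain as neg(pad(t0)) (the caller separately checks that the op treats
+// zero as a fixed point or identity, see zeroIsFixedPoint / zeroIsIdentity).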
std::vector maybeMovePadBeforeDefinition( - TensorView* tv, + TensorView* pad_inp, const std::unordered_set& pad_dependencies, std::vector& stack, std::unordered_set simple_pad_set) { std::vector padded_inputs; // stop propagation if current PadOp p isn't the only use of tv, since // it requires tv to be live in the fusion. - if (tv->uses().size() != 1) { + if (pad_inp->uses().size() != 1) { return padded_inputs; } - Expr* expr = tv->definition(); + Expr* pad_inp_def = pad_inp->definition(); // stop propagation if any of expr's inputs are not TensorView, which we // cannot pad. - if (std::any_of(expr->inputs().begin(), expr->inputs().end(), [](Val* val) { - return !val->isA(); - })) { + if (std::any_of( + pad_inp_def->inputs().begin(), + pad_inp_def->inputs().end(), + [](Val* val) { return !val->isA(); })) { return padded_inputs; } // stop propagation if moving pad before definition would create cycles NVF_ERROR( - expr->outputs().size() == 1, + pad_inp_def->outputs().size() == 1, "expects tv to be the only output from its definition") - if (pad_dependencies.count(tv) > 0) { + if (pad_dependencies.count(pad_inp) > 0) { return padded_inputs; } - PadOp* p = tv->uses()[0]->as(); - padded_inputs.reserve(expr->inputs().size()); + PadOp* pad = pad_inp->uses()[0]->as(); + padded_inputs.reserve(pad_inp_def->inputs().size()); std::transform( - expr->inputs().begin(), - expr->inputs().end(), + pad_inp_def->inputs().begin(), + pad_inp_def->inputs().end(), std::back_inserter(padded_inputs), - [&p, &stack, &simple_pad_set](Val* val) { + [&pad, &stack, &simple_pad_set](Val* inp_of_pad_inp) { TensorView* new_pad_in = replayConcretePad( - val, - p->value(), - {p->getPadWidths()}, + inp_of_pad_inp, + pad->value(), + {pad->getPadWidths()}, TensorDomain::noReductions( - p->out()->as()->getLogicalDomain())); + pad->out()->as()->getLogicalDomain())); PadOp* new_pad_op = new_pad_in->definition()->as(); stack.push_back(new_pad_op); simple_pad_set.insert(new_pad_op); @@ -398,93 +399,93 @@ void propagatePads(Fusion* fusion) { std::vector pad_to_be_removed; while (!stack.empty()) { - PadOp* p = stack.back(); + PadOp* pad = stack.back(); stack.pop_back(); // if no uses, this has already been short-wired. - if (p->out()->uses().empty() && !p->out()->isFusionOutput()) { + if (pad->out()->uses().empty() && !pad->out()->isFusionOutput()) { continue; } // unify all consumer pad of tv; - auto* tv = p->in()->as(); - for (Expr* use : tv->uses()) { - if (use == p) { + auto* pad_inp = pad->in()->as(); + for (Expr* uses_of_pad_op_inp : pad_inp->uses()) { + if (uses_of_pad_op_inp == pad) { continue; } // check if use is the same pad operation (same pad value / width e.t.c.) - if (isSamePadOp(use, p)) { + if (isSamePadOp(uses_of_pad_op_inp, pad)) { // replace consumer of use->out() with p->out() ir_utils::replaceValInAllExprInputsAndFusionOutputs( - use->output(0), p->out()); + uses_of_pad_op_inp->output(0), pad->out()); // we could remove `use`, but `use` could still be in stack and needs to // be visited later. So push it to a vector and we'll remove it later. - pad_to_be_removed.push_back(use); + pad_to_be_removed.push_back(uses_of_pad_op_inp); } } // if tv is fusion output, we need to keep tv alive, it might render // propagating PadOp before tv->definition() being non-optimal. - if (tv->isFusionOutput()) { + if (pad_inp->isFusionOutput()) { continue; } // check for pad_dependencies to verify that 'p' can be moved before 'def'. 
std::unordered_set pad_inputs; - for (Val* val : p->inputs()) { - if (val == p->in() || val->isConst()) { + for (Val* pad_inp : pad->inputs()) { + if (pad_inp == pad->in() || pad_inp->isConst()) { continue; } - pad_inputs.insert(val); + pad_inputs.insert(pad_inp); } std::unordered_set pad_dependencies = DependencyCheck::getAllDependentVals(pad_inputs); - Expr* def = p->in()->definition(); + Expr* def_of_pad_in = pad->in()->definition(); Val* new_out = nullptr; - if (auto* uop = dynamic_cast(def)) { + if (auto* uop = dynamic_cast(def_of_pad_in)) { // check if unary op type is compatible for zero pad propagation. if (!zeroIsFixedPoint(uop->getUnaryOpType())) { continue; } - std::vector new_pad_inputs = maybeMovePadBeforeDefinition( - tv, pad_dependencies, std::ref(stack), std::ref(simple_pad_set)); + std::vector outputs_of_moved_pad = maybeMovePadBeforeDefinition( + pad_inp, pad_dependencies, std::ref(stack), std::ref(simple_pad_set)); // stop when move pad fails. - if (new_pad_inputs.empty()) { + if (outputs_of_moved_pad.empty()) { continue; } // update new outputs. - new_out = - ops::newValLike(new_pad_inputs[0], uop->out()->getDataType().value()); + new_out = ops::newValLike( + outputs_of_moved_pad[0], uop->out()->getDataType().value()); IrBuilder::create( - uop->getUnaryOpType(), new_out, new_pad_inputs[0]); - } else if (auto* bop = dynamic_cast(def)) { + uop->getUnaryOpType(), new_out, outputs_of_moved_pad[0]); + } else if (auto* bop = dynamic_cast(def_of_pad_in)) { // check if unary op type is compatible for zero pad propagation. if (!zeroIsIdentity(bop->getBinaryOpType())) { continue; } // check for broadcast on padded axis. - if (hasBroadcastOnAny(p->getPaddedAxes(), bop->inputs())) { + if (hasBroadcastOnAny(pad->getPaddedAxes(), bop->inputs())) { continue; } - std::vector new_pad_inputs = maybeMovePadBeforeDefinition( - tv, pad_dependencies, std::ref(stack), std::ref(simple_pad_set)); + std::vector outputs_of_moved_pad = maybeMovePadBeforeDefinition( + pad_inp, pad_dependencies, std::ref(stack), std::ref(simple_pad_set)); // stop when move pad fails. - if (new_pad_inputs.empty()) { + if (outputs_of_moved_pad.empty()) { continue; } - new_out = - ops::newOutputTV(new_pad_inputs, bop->out()->getDataType().value()); + new_out = ops::newValLike(pad->output(0), pad->output(0)->dtype()); + IrBuilder::create( bop->getBinaryOpType(), new_out, - new_pad_inputs[0], - new_pad_inputs[1]); + outputs_of_moved_pad[0], + outputs_of_moved_pad[1]); // insert new PadOp(s) to stack; - } else if (auto* pop = dynamic_cast(def)) { + } else if (auto* pop = dynamic_cast(def_of_pad_in)) { // stop propagation if PadOp `pop` isn't a simple PadOp, since we can // only merge simple PadOp together. Note that we don't need to check // the other uses of `tv` here, since we want to merge the consecutive @@ -497,20 +498,20 @@ void propagatePads(Fusion* fusion) { new_out = replayConcretePad( pop->in()->as(), pop->value(), - {pop->getPadWidths(), p->getPadWidths()}, + {pop->getPadWidths(), pad->getPadWidths()}, TensorDomain::noReductions( - p->out()->as()->getLogicalDomain())); + pad->out()->as()->getLogicalDomain())); // insert new PadOp(s) to stack; stack.push_back(new_out->definition()->as()); simple_pad_set.insert(new_out->definition()->as()); - } else if (def->isA()) { + } else if (def_of_pad_in->isA()) { // TODO: can cat support broadcast on any non-cat dimensions? 
Otherwise
      // we need to ensure that we are not padding on broadcast dimensions
      // like binary op
      // check if PadOp can be replayed on input(s)
      std::vector new_pad_inputs = maybeMovePadBeforeDefinition(
-          tv, pad_dependencies, std::ref(stack), std::ref(simple_pad_set));
+          pad_inp, pad_dependencies, std::ref(stack), std::ref(simple_pad_set));
      // stop when move pad fails.
      if (new_pad_inputs.empty()) {
        continue;
@@ -520,7 +521,7 @@ void propagatePads(Fusion* fusion) {
    }
    // replace old (->pad->) with (->pads_before_new_def->new_def->)
    if (new_out != nullptr) {
-      ir_utils::replaceValInAllExprInputsAndFusionOutputs(p->out(), new_out);
+      ir_utils::replaceValInAllExprInputsAndFusionOutputs(pad->out(), new_out);
    }
  }
diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp
index 2ad82f9dd20..b4943f1c91e 100644
--- a/csrc/preseg_passes/pre_segmenter.cpp
+++ b/csrc/preseg_passes/pre_segmenter.cpp
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include
 namespace nvfuser::preseg_passes {
@@ -45,6 +46,9 @@ namespace nvfuser::preseg_passes {
   // Replace TensorViews with zero extent. Outputs and inputs may still be empty
   OptimizationPass::runPass(fusion);
+  // This pass should be placed before ConsecutiveCastPass as more
+  // consecutive cast ops may be exposed by this pass
+  OptimizationPass::runPass(fusion);
   // removes consecutive cast operations
   OptimizationPass::runPass(fusion);
   OptimizationPass::runPass(fusion);
diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp
index 0a6d3765dd9..f6359cb424e 100644
--- a/csrc/preseg_passes/reorder_sharded_axis.cpp
+++ b/csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -9,10 +9,10 @@
 #include
 #include
+#include
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -20,6 +20,8 @@ namespace nvfuser::preseg_passes {
 void ReorderShardedAxisPass::runPass(Fusion* fusion) {
+  FusionGuard fg(fusion);
+
   const std::vector& exprs = fusion->exprs();
   for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
     Expr* expr = *it;
diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp
new file mode 100644
index 00000000000..382dcb85f52
--- /dev/null
+++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp
@@ -0,0 +1,195 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+namespace nvfuser::preseg_passes {
+
+namespace {
+
+struct RepetitionInfo {
+  // Input tensor that is repeated
+  TensorView* input_tv = nullptr;
+  // Repeated logical ID of the input tensor
+  IterDomain* repeated_id = nullptr;
+  // Tensors fed into the concat op
+  std::vector cat_inp_tvs;
+};
+
+// Translation algorithm overview:
+//
+// Step 1: Inspection. Traverses the given fusion and looks for a
+// sequence of ops that correspond to a repetition. See
+// RepeatToExpandTranslator::inspect() for more details.
+//
+// Step 2: Apply the translation in a reverse topological order. See
+// RepeatToExpandTranslator::translate() for more details.
+class RepeatToExpandTranslator {
+ public:
+  RepeatToExpandTranslator(Fusion* fusion) : fusion_(fusion) {}
+
+  void run() {
+    inspect();
+    translate();
+  }
+
+ private:
+  // Traverse through the fusion and gather all patterns of a pad
+  // followed by a concat.
If a single concat op has multiple pad
+  // inputs that resize the same iter domain of the same input tensor,
+  // that must correspond to a repetition.
+  void inspect() {
+    const auto exprs = fusion_->exprs();
+
+    for (auto pad : ir_utils::filterByType(exprs)) {
+      auto pad_inp = pad->input(0)->as();
+      auto pad_out = pad->output(0)->as();
+
+      // Not supported if there are multiple expanded logical IDs
+      IterDomain* out_padded_root_id = nullptr;
+      bool multiple_resizes_found = false;
+      for (const auto i : c10::irange(pad_out->getLogicalDomain().size())) {
+        auto out_logical_id = pad_out->getLogicalDomain().at(i);
+        auto resize = dynamic_cast(out_logical_id->definition());
+        if (resize == nullptr) {
+          continue;
+        }
+        if (out_padded_root_id != nullptr) {
+          // Multiple IDs are resized. Not supported.
+          multiple_resizes_found = true;
+          break;
+        }
+        out_padded_root_id = resize->in();
+      }
+
+      if (multiple_resizes_found || out_padded_root_id == nullptr) {
+        // Unsupported pattern
+        break;
+      }
+
+      auto inp_padded_id = PairwiseLogicalDomainMap(pad_inp, pad_out)
+                               .mapConsumerToProducer()
+                               .at(out_padded_root_id);
+
+      // The padded tensor must be immediately used by a concat only
+      if (pad_out->uses().size() != 1 || !pad_out->uses().at(0)->isA()) {
+        continue;
+      }
+
+      auto cat_op = pad_out->uses().at(0);
+
+      // If other inputs to the same concat op are already found, make
+      // sure this path from the pad op is compatible with the known
+      // ops.
+      if (auto it = repeat_info_map_.find(cat_op);
+          it == repeat_info_map_.end()) {
+        RepetitionInfo info;
+        info.input_tv = pad_inp;
+        info.repeated_id = inp_padded_id;
+        info.cat_inp_tvs.push_back(pad_out);
+        repeat_info_map_.emplace(cat_op, info);
+      } else {
+        auto& info = repeat_info_map_.at(cat_op);
+        if (info.input_tv != pad_inp || info.repeated_id != inp_padded_id) {
+          // Invalid
+          repeat_info_map_.erase(cat_op);
+          continue;
+        }
+        info.cat_inp_tvs.push_back(pad_out);
+      }
+    }
+
+    // Remove invalid entries
+    for (auto it = repeat_info_map_.begin(); it != repeat_info_map_.end();) {
+      Expr* concatenating_expr = it->first;
+      const RepetitionInfo& info = it->second;
+      // Make sure all inputs to concatenating_expr are detected
+      if (concatenating_expr->inputs().size() != info.cat_inp_tvs.size()) {
+        // Invalid
+        it = repeat_info_map_.erase(it);
+        continue;
+      }
+      ++it;
+    }
+  }
+
+  // For each detected repetition:
+  //
+  // Step 1. Insert a broadcast ID immediately outside of the
+  // repeated ID
+  // Step 2. Expand the broadcast ID by the repetition factor
+  // Step 3. Flatten the expanded ID and the repeated ID
+  void translate() {
+    const auto exprs = fusion_->exprs();
+    // Apply the translation in a reverse topological order. Since the
+    // output of the repetition is replaced, the use exprs of the
+    // output are replaced too, which may render the inspected
+    // info invalid.
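+    //
+    // For instance, repeating a [i0, i1] tensor twice along its first axis
+    // becomes:
+    //   broadcast: [i0, i1]     -> [b, i0, i1]
+    //   expand:    [b, i0, i1]  -> [2, i0, i1]
+    //   flatten:   [2, i0, i1]  -> [2*i0, i1]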
+    for (auto exprs_it = exprs.rbegin(); exprs_it != exprs.rend(); ++exprs_it) {
+      Expr* expr = *exprs_it;
+      auto repeat_info_map_it = repeat_info_map_.find(expr);
+      if (repeat_info_map_it == repeat_info_map_.end()) {
+        continue;
+      }
+
+      const auto& info = repeat_info_map_it->second;
+
+      if (info.cat_inp_tvs.size() < 2) {
+        continue;
+      }
+
+      auto original_out_tv = expr->output(0)->as();
+
+      // Step 1
+      auto inp_domain =
+          TensorDomain::noReductions(info.input_tv->getLogicalDomain());
+      std::vector bcast_flags(inp_domain.size() + 1, false);
+      auto repeated_id_offset = std::distance(
+          inp_domain.begin(),
+          std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id));
+      bcast_flags.at(repeated_id_offset) = true;
+      auto broadcast_tv = broadcast(info.input_tv, bcast_flags);
+      NVF_ERROR((size_t)broadcast_tv->nDims() == inp_domain.size() + 1);
+
+      // Step 2
+      std::vector expanded_sizes(
+          bcast_flags.size(), IrBuilder::create(-1L));
+      expanded_sizes.at(repeated_id_offset) =
+          IrBuilder::create((int64_t)info.cat_inp_tvs.size());
+      auto expanded_tv = expand(broadcast_tv, expanded_sizes);
+
+      // Step 3
+      auto flattened_tv =
+          flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1);
+
+      ir_utils::replaceValInAllExprInputsAndFusionOutputs(
+          original_out_tv, flattened_tv);
+    }
+  }
+
+ private:
+  Fusion* fusion_ = nullptr;
+  // Map of concat exprs to their info about repetition
+  std::unordered_map repeat_info_map_;
+};
+
+} // namespace
+
+void TranslateRepeatToExpand::runPass(Fusion* fusion) {
+  FusionGuard fg(fusion);
+  RepeatToExpandTranslator translator(fusion);
+  translator.run();
+}
+
+} // namespace nvfuser::preseg_passes
diff --git a/csrc/preseg_passes/translate_repeat_to_expand.h b/csrc/preseg_passes/translate_repeat_to_expand.h
new file mode 100644
index 00000000000..bc5d3ed5b35
--- /dev/null
+++ b/csrc/preseg_passes/translate_repeat_to_expand.h
@@ -0,0 +1,54 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include
+#include
+
+namespace nvfuser::preseg_passes {
+
+// Translate concat-based repetitions to expand and reshape ops.
+//
+// For example, given the following fusion:
+//
+// t0 = [i0];
+// t1 = cat({t0, t0}, -1);
+//
+// It will be translated to:
+//
+// t0 = [i0]
+// t2 = broadcast(t0, {true, false});
+// t3 = expand(t2, {2, i0});
+// t4 = reshape(t3, {2 * i0});
+//
+// And all uses of t1 will be replaced by t4. This pattern commonly
+// appears in RoPE, e.g.,
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136.
+// While the resize scheduler should be able to handle these patterns for
+// pointwise-only segments, it is currently limited to pointwise
+// fusions only. This translation should promote larger fusions
+// as it is not specific to any surrounding ops.
+//
+// Note that there's a potential downside compared to handling cat ops
+// directly. Since insertion of broadcast IDs is not represented as
+// Fusion IR expressions, a fusion may have more disconnected ID
+// graphs after the translation, which may cause a segmentation that
+// could be avoided with the original fusion. See
+// PresegTest.TranslateRepeatToExpand4 for a concrete example.
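+//
+// A minimal sketch of how this pass might be exercised in isolation
+// (hypothetical test code; assumes the usual C++ test helpers):
+//
+//   Fusion fusion;
+//   FusionGuard fg(&fusion);
+//   TensorView* tv0 = makeSymbolicTensor(1);
+//   fusion.addInput(tv0);
+//   TensorView* tv1 = cat({tv0, tv0}, /*dim=*/-1);
+//   fusion.addOutput(tv1);
+//   OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion);
+//   // The fusion output is now the flattened expand rather than the cat result.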
+class TranslateRepeatToExpand + : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static std::string name() { + return "TranslateRepeatToExpand"; + } +}; + +} // namespace nvfuser::preseg_passes diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index 6157704f86b..28fc4b8b484 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -6,14 +6,15 @@ */ // clang-format on #pragma once + #include #include #include +#include #include #include #include -#include namespace nvfuser::python_frontend { diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index c72d7e9f178..ef7717b95f8 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -5,10 +5,18 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include +#include +#include #include #include +#include + +#include +#include +#include + #include #include #include @@ -28,16 +36,8 @@ #include #include #include -#include #include #include -#include -#include -#include - -#include -#include -#include namespace nvfuser::python_frontend { @@ -850,7 +850,8 @@ void initNvFuserPythonBindings(PyObject* module) { .value("inner_outer_persistent", SchedulerType::InnerOuterPersistent) .value("outer_persistent", SchedulerType::OuterPersistent) .value("transpose", SchedulerType::Transpose) - .value("expr_eval", SchedulerType::ExprEval); + .value("expr_eval", SchedulerType::ExprEval) + .value("resize", SchedulerType::Resize); nvfuser.def("compute_contiguity", computeContiguity); nvfuser.def("compute_tensor_descriptor", computeTensorDescriptor); @@ -3595,501 +3596,7 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("scale").none(true) = py::none(), py::return_value_policy::reference); - //! The ScedOperators class is a nested class of FusionDefinition to allow the - //! user to query the class for the list of schedule operators. - //! - //! Example: - //! help(FusionDefinition.SchedOperators) - //! - //! Additional operators are expected to be defined below as needed. - py::class_ nvf_sched( - fusion_def, "SchedOperators"); - nvf_sched.def(py::init()); - nvf_sched.def( - "to_string", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - // NOTE: For debugging purposes, print the state of TensorView - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return tv->toString(); - }, - py::arg("tensor")); - nvf_sched.def( - "user_schedule_ir", - [](FusionDefinition::SchedOperators& self) { - return self.fusion_definition->userScheduleIr(); - }, - py::return_value_policy::reference); - //! 
experimental API for multidevice support - nvf_sched.def( - "_set_device_mesh", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const DeviceMesh& mesh) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->setDeviceMesh(mesh); - }, - py::arg("tensor"), - py::arg("mesh")); - nvf_sched.def( - "parallelize", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int axis, - const ParallelType& parallel_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->axis(axis)->parallelize(parallel_type); - }, - py::arg("tensor"), - py::arg("axis"), - py::arg("parallel_type")); - nvf_sched.def( - "merge", - [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) { - FUSER_PERF_SCOPE("SchedOperators.merge"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->merge(dim); - }, - py::arg("arg"), - py::arg("dim")); - auto reduction_factor_func = [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::vector& dims) -> Tensor { - FUSER_PERF_SCOPE("SchedOperators.reduction_factor"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(arg.index)->template as(); - TensorView* output_tv = input_tv->rFactor(dims); - return fd->addTensor(output_tv); - }; - nvf_sched.def( - "reduction_factor", - reduction_factor_func, - py::arg("arg"), - py::arg("dims")); - nvf_sched.def( - "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims")); - nvf_sched.def( - "reorder", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::unordered_map& old2new) { - FUSER_PERF_SCOPE("SchedOperators.reorder"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->reorder(old2new); - }, - py::arg("arg"), - py::arg("old2new")); - nvf_sched.def( - "split", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - int64_t dim, - int64_t factor, - bool inner_split) { - FUSER_PERF_SCOPE("SchedOperators.split"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->split(dim, factor, inner_split); - }, - py::arg("arg"), - py::arg("dim"), - py::arg("factor"), - py::arg("inner_split") = true); - nvf_sched.def( - "set_allocation_as_loop", - [](FusionDefinition::SchedOperators& self, Tensor arg) { - FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto* tv = fd->getFusionState(arg.index)->template as(); - tv->setAllocationDomain(tv->getLoopDomain(), true); - }, - py::arg("arg")); - nvf_sched.def( - "cache_after", - [](FusionDefinition::SchedOperators& self, - 
Tensor tensor, - const LoadStoreOpType& op_type, - const CacheOp& cache_op) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set, - py::arg("cache_op") = CacheOp::Unspecified); - nvf_sched.def( - "cache_before", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const LoadStoreOpType& op_type) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheBefore(op_type); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set); - nvf_sched.def( - "cache_fork", - [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheFork(); - return fd->addTensor(output_tv); - }, - py::arg("tensor")); - nvf_sched.def( - "set_memory_type", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const MemoryType& memory_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - tv->setMemoryType(memory_type); - }, - py::arg("tensor"), - py::arg("memory_type")); - nvf_sched.def( - "transform_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - TransformPropagator propagator(reference_tv); - if (selected_tensors.empty()) { - // Propagate scheduler transformations on reference TensorView to the - // rest of the fusion. - MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); - } else { - // Propagate scheduler transformations on reference TensorView to the - // subset of the fusion. 
- std::unordered_set selected_tv_set; - selected_tv_set.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tv_set, selected_tv_set.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - SetSelector selector( - {selected_tv_set.begin(), selected_tv_set.end()}); - MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) - .traverse(&propagator); - } - }, - py::arg("tensor"), - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "parallelize_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - const std::vector& selected_tensors, - const std::unordered_set& selected_parallel_types, - bool propagate_padding) { - // Propagate the parallelization from the selected dimensions of the - // reference tensor to their corresponding dimensions in all selected - // tensors in the DAG. - // - // 1. Position `pos` means selecting all the dimensions - // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions. - // 2. `selected_tvs` are selected tensors in the DAG. Empty - // `selected_tvs` means selecting all tensors in the fusion of - // `reference_tv`. - // 3. `selected_parallel_types` are the selected parallel types. Empty - // `selected_parallel_types` means selecting all parallel types. - - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::scheduler_utils::parallelizeAllLike( - reference_tv, - pos, - selected_tvs, - selected_parallel_types, - propagate_padding); - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("selected_tensors") = std::vector(), - py::arg("selected_parallel_types") = std::unordered_set(), - py::arg("propagate_padding") = true); - nvf_sched.def( - "inline_most", - [](FusionDefinition::SchedOperators& self, - const std::vector& selected_tensors) { - // Inline to the right most allowed position for the selected tensors in - // the current fusion. 
- - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - - if (selected_tensors.empty()) { - nvfuser::inlineMost(); - } else { - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - nvfuser::inlineMost(selected_tvs); - } - }, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "inline_at", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - bool best_effort, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - if (selected_tensors.empty()) { - // Inline to the position corresponding to the reference position in - // the reference tensor for all tensors in the current fusion. - nvfuser::inlineAllAt(reference_tv, pos, best_effort); - } else { - // Inline to the position corresponding to the reference position in - // the reference tensor for selected tensors in the current fusion. - std::unordered_set selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tvs, selected_tvs.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::inlineSelectedAt( - selected_tvs, reference_tv, pos, best_effort); - } - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("best_effort") = false, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Return all Tensors in FusionDefinition - return self.fusion_definition->tensors(); - }); - nvf_sched.def( - "is_reduction", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. 
- FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return ( - !tv->isFusionInput() && - std::any_of( - tv->getMaybeRootDomain().begin(), - tv->getMaybeRootDomain().end(), - [](IterDomain* id) { return id->isReduction(); }) && - !isResharding(tv->definition())); - }, - py::arg("tensor")); - nvf_sched.def( - "can_schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - return self.fusion_definition->userSchedule()->canScheduleDebug( - scheduler_type); - }, - py::arg("scheduler_type")); - nvf_sched.def( - "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - std::vector valid_scheduler_types; - valid_scheduler_types.reserve(all_heuristics_in_priority_order.size()); - std::copy_if( - all_heuristics_in_priority_order.begin(), - all_heuristics_in_priority_order.end(), - std::back_inserter(valid_scheduler_types), - [sched = self.fusion_definition->userSchedule()]( - SchedulerType scheduler_type) { - return sched->canSchedule(scheduler_type); - }); - return valid_scheduler_types; - }); - nvf_sched.def( - "schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - auto&& [can_schedule, error_msg] = - sched->canScheduleDebug(scheduler_type); - NVF_CHECK(can_schedule, error_msg); - sched->scheduleWithType(scheduler_type); - }, - py::arg("heuristic")); - nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - sched->schedule(); - }); - nvf_sched.def( - "compute_pointwise_heuristics", - [](FusionDefinition::SchedOperators& self) -> PointwiseParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::PointWise); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_reduction_heuristics", - [](FusionDefinition::SchedOperators& self) -> ReductionParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Reduction); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_matmul_heuristics", - [](FusionDefinition::SchedOperators& self) -> MatmulParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Matmul); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "schedule_hyperparameters", - [](FusionDefinition::SchedOperators& self) - -> scheduler_utils::SchedulerHyperParameters& { - NVF_CHECK( - self.validUse(), - "Attempting to use a 
SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< - HeuristicCompileTime::SchedulerHyperParameters>( - sched->data_cache.get(), []() { - return std::make_unique< - scheduler_utils::SchedulerHyperParameters>( - /*vectorize_factor=*/1, - /*unroll_factor=*/1, - /*threads_per_block_min=*/1, - /*threads_per_block_max=*/1); - }); - return scheduler_hyperparameters_entry.get(); - }, - py::return_value_policy::reference); + bindSchedule(fusion_def); } void cleanup() { diff --git a/csrc/python_frontend/python_bindings.h b/csrc/python_frontend/python_bindings.h index a698619eb4e..bd8f0347530 100644 --- a/csrc/python_frontend/python_bindings.h +++ b/csrc/python_frontend/python_bindings.h @@ -10,10 +10,13 @@ #include #include +#include #include namespace nvfuser::python_frontend { NVF_API void initNvFuserPythonBindings(PyObject* module); +void bindSchedule(py::class_& fusion_def); + NVF_API void cleanup(); } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/schedule_bindings.cpp b/csrc/python_frontend/schedule_bindings.cpp new file mode 100644 index 00000000000..b77982711cb --- /dev/null +++ b/csrc/python_frontend/schedule_bindings.cpp @@ -0,0 +1,517 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::python_frontend { + +void bindSchedule(py::class_& fusion_def) { + //! The SchedOperators class is a nested class of FusionDefinition to allow + //! the user to query the class for the list of schedule operators. + //! + //! Example: + //! help(FusionDefinition.SchedOperators) + //! + //! Additional operators are expected to be defined below as needed. + py::class_ nvf_sched( + fusion_def, "SchedOperators"); + nvf_sched.def(py::init()); + nvf_sched.def( + "to_string", + [](FusionDefinition::SchedOperators& self, Tensor tensor) { + // NOTE: For debugging purposes, print the state of TensorView + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Determine if tensor is a result from a reduction operation. + FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + return tv->toString(); + }, + py::arg("tensor")); + nvf_sched.def( + "user_schedule_ir", + [](FusionDefinition::SchedOperators& self) { + return self.fusion_definition->userScheduleIr(); + }, + py::return_value_policy::reference); + //! 
experimental API for multidevice support + nvf_sched.def( + "_set_device_mesh", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const DeviceMesh& mesh) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto tv = fd->getFusionState(tensor.index)->template as(); + tv->setDeviceMesh(mesh); + }, + py::arg("tensor"), + py::arg("mesh")); + nvf_sched.def( + "parallelize", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int axis, + const ParallelType& parallel_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto tv = fd->getFusionState(tensor.index)->template as(); + tv->axis(axis)->parallelize(parallel_type); + }, + py::arg("tensor"), + py::arg("axis"), + py::arg("parallel_type")); + nvf_sched.def( + "merge", + [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) { + FUSER_PERF_SCOPE("SchedOperators.merge"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->merge(dim); + }, + py::arg("arg"), + py::arg("dim")); + auto reduction_factor_func = [](FusionDefinition::SchedOperators& self, + Tensor arg, + const std::vector& dims) -> Tensor { + FUSER_PERF_SCOPE("SchedOperators.reduction_factor"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(arg.index)->template as(); + TensorView* output_tv = input_tv->rFactor(dims); + return fd->addTensor(output_tv); + }; + nvf_sched.def( + "reduction_factor", + reduction_factor_func, + py::arg("arg"), + py::arg("dims")); + nvf_sched.def( + "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims")); + nvf_sched.def( + "reorder", + [](FusionDefinition::SchedOperators& self, + Tensor arg, + const std::unordered_map& old2new) { + FUSER_PERF_SCOPE("SchedOperators.reorder"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->reorder(old2new); + }, + py::arg("arg"), + py::arg("old2new")); + nvf_sched.def( + "split", + [](FusionDefinition::SchedOperators& self, + Tensor arg, + int64_t dim, + int64_t factor, + bool inner_split) { + FUSER_PERF_SCOPE("SchedOperators.split"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->split(dim, factor, inner_split); + }, + py::arg("arg"), + py::arg("dim"), + py::arg("factor"), + py::arg("inner_split") = true); + nvf_sched.def( + "set_allocation_as_loop", + [](FusionDefinition::SchedOperators& self, Tensor arg) { + FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto* tv = fd->getFusionState(arg.index)->template as(); + tv->setAllocationDomain(tv->getLoopDomain(), true); + }, + py::arg("arg")); + nvf_sched.def( + "cache_after", + [](FusionDefinition::SchedOperators& self, + 
Tensor tensor, + const LoadStoreOpType& op_type, + const CacheOp& cache_op) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op); + return fd->addTensor(output_tv); + }, + py::arg("tensor"), + py::arg("op_type") = LoadStoreOpType::Set, + py::arg("cache_op") = CacheOp::Unspecified); + nvf_sched.def( + "cache_before", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const LoadStoreOpType& op_type) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheBefore(op_type); + return fd->addTensor(output_tv); + }, + py::arg("tensor"), + py::arg("op_type") = LoadStoreOpType::Set); + nvf_sched.def( + "cache_fork", + [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheFork(); + return fd->addTensor(output_tv); + }, + py::arg("tensor")); + nvf_sched.def( + "set_memory_type", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const MemoryType& memory_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + tv->setMemoryType(memory_type); + }, + py::arg("tensor"), + py::arg("memory_type")); + nvf_sched.def( + "transform_like", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const std::vector& selected_tensors) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + TransformPropagator propagator(reference_tv); + if (selected_tensors.empty()) { + // Propagate scheduler transformations on reference TensorView to the + // rest of the fusion. + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); + } else { + // Propagate scheduler transformations on reference TensorView to the + // subset of the fusion. 
+ std::unordered_set selected_tv_set; + selected_tv_set.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::inserter(selected_tv_set, selected_tv_set.end()), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + SetSelector selector( + {selected_tv_set.begin(), selected_tv_set.end()}); + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) + .traverse(&propagator); + } + }, + py::arg("tensor"), + py::arg("selected_tensors") = std::vector()); + nvf_sched.def( + "parallelize_like", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int64_t pos, + const std::vector& selected_tensors, + const std::unordered_set& selected_parallel_types, + bool propagate_padding) { + // Propagate the parallelization from the selected dimensions of the + // reference tensor to their corresponding dimensions in all selected + // tensors in the DAG. + // + // 1. Position `pos` means selecting all the dimensions + // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions. + // 2. `selected_tvs` are selected tensors in the DAG. Empty + // `selected_tvs` means selecting all tensors in the fusion of + // `reference_tv`. + // 3. `selected_parallel_types` are the selected parallel types. Empty + // `selected_parallel_types` means selecting all parallel types. + + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + std::vector selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::back_inserter(selected_tvs), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + + nvfuser::scheduler_utils::parallelizeAllLike( + reference_tv, + pos, + selected_tvs, + selected_parallel_types, + propagate_padding); + }, + py::arg("tensor"), + py::arg("pos") = -1, + py::arg("selected_tensors") = std::vector(), + py::arg("selected_parallel_types") = std::unordered_set(), + py::arg("propagate_padding") = true); + nvf_sched.def( + "inline_most", + [](FusionDefinition::SchedOperators& self, + const std::vector& selected_tensors) { + // Inline to the right most allowed position for the selected tensors in + // the current fusion. 
+ + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + + if (selected_tensors.empty()) { + nvfuser::inlineMost(); + } else { + std::vector selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::back_inserter(selected_tvs), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + nvfuser::inlineMost(selected_tvs); + } + }, + py::arg("selected_tensors") = std::vector()); + nvf_sched.def( + "inline_at", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int64_t pos, + bool best_effort, + const std::vector& selected_tensors) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + if (selected_tensors.empty()) { + // Inline to the position corresponding to the reference position in + // the reference tensor for all tensors in the current fusion. + nvfuser::inlineAllAt(reference_tv, pos, best_effort); + } else { + // Inline to the position corresponding to the reference position in + // the reference tensor for selected tensors in the current fusion. + std::unordered_set selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::inserter(selected_tvs, selected_tvs.end()), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + + nvfuser::inlineSelectedAt( + selected_tvs, reference_tv, pos, best_effort); + } + }, + py::arg("tensor"), + py::arg("pos") = -1, + py::arg("best_effort") = false, + py::arg("selected_tensors") = std::vector()); + nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Return all Tensors in FusionDefinition + return self.fusion_definition->tensors(); + }); + nvf_sched.def( + "is_reduction", + [](FusionDefinition::SchedOperators& self, Tensor tensor) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Determine if tensor is a result from a reduction operation. 
+ FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + return ( + !tv->isFusionInput() && + std::any_of( + tv->getMaybeRootDomain().begin(), + tv->getMaybeRootDomain().end(), + [](IterDomain* id) { return id->isReduction(); }) && + !isResharding(tv->definition())); + }, + py::arg("tensor")); + nvf_sched.def( + "can_schedule", + [](FusionDefinition::SchedOperators& self, + const SchedulerType& scheduler_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + return self.fusion_definition->userSchedule()->canScheduleDebug( + scheduler_type); + }, + py::arg("scheduler_type")); + nvf_sched.def( + "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + std::vector valid_scheduler_types; + valid_scheduler_types.reserve(all_heuristics_in_priority_order.size()); + std::copy_if( + all_heuristics_in_priority_order.begin(), + all_heuristics_in_priority_order.end(), + std::back_inserter(valid_scheduler_types), + [sched = self.fusion_definition->userSchedule()]( + SchedulerType scheduler_type) { + return sched->canSchedule(scheduler_type); + }); + return valid_scheduler_types; + }); + nvf_sched.def( + "schedule", + [](FusionDefinition::SchedOperators& self, + const SchedulerType& scheduler_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto&& [can_schedule, error_msg] = + sched->canScheduleDebug(scheduler_type); + NVF_CHECK(can_schedule, error_msg); + sched->scheduleWithType(scheduler_type); + }, + py::arg("heuristic")); + nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + sched->schedule(); + }); + nvf_sched.def( + "compute_pointwise_heuristics", + [](FusionDefinition::SchedOperators& self) -> PointwiseParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::PointWise); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "compute_reduction_heuristics", + [](FusionDefinition::SchedOperators& self) -> ReductionParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::Reduction); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "compute_matmul_heuristics", + [](FusionDefinition::SchedOperators& self) -> MatmulParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::Matmul); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "schedule_hyperparameters", + [](FusionDefinition::SchedOperators& self) + -> scheduler_utils::SchedulerHyperParameters& { + NVF_CHECK( + self.validUse(), + "Attempting to use a 
SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>( + sched->data_cache.get(), []() { + return std::make_unique< + scheduler_utils::SchedulerHyperParameters>( + /*vectorize_factor=*/1, + /*unroll_factor=*/1, + /*threads_per_block_min=*/1, + /*threads_per_block_max=*/1); + }); + return scheduler_hyperparameters_entry.get(); + }, + py::return_value_policy::reference); +} + +} // namespace nvfuser::python_frontend diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index 8becb528951..04f86b1edd0 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -304,6 +305,9 @@ void KernelExecutor::compile( NVF_ERROR( !fusion->outputs().empty(), "No output found for this kernel, aborting."); + createKernelId( + scheduler_type, fusion_id_, concrete_id_, runtime_id_, group_id_); + // TODO: refactor the options_ passed through options_.device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex()); @@ -346,10 +350,21 @@ void KernelExecutor::compile( } } + if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) { + fusion->printMath(); + } + if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) { fusion->print(); - } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) { - fusion->printMath(); + } + + if (isDebugDumpEnabled(DebugDumpOption::FusionIrGraph)) { + std::stringstream file_name; + file_name << "__tmp_fusion_ir_graph_" << kernel_id_ << ".dot"; + IrGraphGenerator::print( + fusion, + file_name.str().c_str(), + IrGraphGenerator::DetailLevel::ComputeOnly); } //! Force index_type to int and disable magic zero if we detect that the @@ -418,8 +433,7 @@ void KernelExecutor::compile( for (const auto& hook : post_lowering_hooks_) { hook(kernel); } - createKernelId( - scheduler_type, fusion_id_, concrete_id_, runtime_id_, group_id_); + setUsedTVs(); if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) { diff --git a/csrc/runtime/executor_kernel_arg.cpp b/csrc/runtime/executor_kernel_arg.cpp index 42033184930..46940ec96ce 100644 --- a/csrc/runtime/executor_kernel_arg.cpp +++ b/csrc/runtime/executor_kernel_arg.cpp @@ -99,7 +99,11 @@ void KernelArgumentHolder::erase(const PolymorphicValue* arg_to_delete) { std::string KernelArgumentHolder::toString() const { std::stringstream ss; for (const auto& arg : arguments_) { - ss << *arg << "\n"; + if (arg->is()) { + ss << debug_str(arg->as()) << "\n"; + } else { + ss << *arg << "\n"; + } } return ss.str(); } diff --git a/csrc/runtime/fusion_cache_utils.cpp b/csrc/runtime/fusion_cache_utils.cpp index 13947ec21b4..a6130898a2f 100644 --- a/csrc/runtime/fusion_cache_utils.cpp +++ b/csrc/runtime/fusion_cache_utils.cpp @@ -42,7 +42,7 @@ ArgumentManager::ArgumentManager( } const std::unordered_map& ArgumentManager:: - getTensorMap() { + getTensorMap() const { return tensor_map_; } const PolymorphicValue* ArgumentManager::checkTensorMap(Val* v) { diff --git a/csrc/runtime/fusion_cache_utils.h b/csrc/runtime/fusion_cache_utils.h index 415dadeae0f..af27814b44a 100644 --- a/csrc/runtime/fusion_cache_utils.h +++ b/csrc/runtime/fusion_cache_utils.h @@ -91,7 +91,7 @@ class ArgumentManager { const RuntimeWorkSpace& runtime_workspace, const std::vector& fusion_inputs); - const std::unordered_map& getTensorMap(); + const std::unordered_map& getTensorMap() const; const PolymorphicValue* 
checkTensorMap(Val* v); @@ -104,6 +104,17 @@ class ArgumentManager { const T& group_runtime_outputs, const int64_t group_id); + std::string toString() const { + std::stringstream ss; + ss << "ArgumentManager {"; + for (auto entry : tensor_map_) { + ss << " " << entry.first->toString() << " : " + << PolymorphicValue_functions::toString(*entry.second) << std::endl; + } + ss << "}" << std::endl; + return ss.str(); + } + private: KernelArgumentHolder& fusion_args_; // map from val to args diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index d582e9e9a10..598482b76c9 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -992,9 +992,6 @@ void AmpereMultipleMatmulScheduler::schedulePrologues() { std::vector& mma_inputs, MmaOperand operand_type) { NVF_ERROR(smem_stores.size() == smem_loads.size()); - // TODO: we should not assume that each operand is used in only a single - // mma op - NVF_ERROR(mma_results_.size() >= smem_loads.size()); // We will save abs_ and bbs_ here for later use // TODO: save all register prologue tensors instead to a new vector called // prologue_register_tensors_ diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index f7ec9d4a97f..3436bcd70eb 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -54,7 +55,7 @@ enum class CompileTimeEntryType { //! stores the domain map of a fusion. class DomainMap { public: - using DataType = pointwise_utils::DomainMap; + using DataType = scheduler_tools::DomainMap; static const CompileTimeEntryType EntryType = CompileTimeEntryType::DOMAIN_MAP; }; @@ -63,7 +64,7 @@ class DomainMap { //! stores the domain map of a fusion, used by transpose scheduler. class TransposeDomainMap { public: - using DataType = pointwise_utils::DomainMap; + using DataType = scheduler_tools::DomainMap; static const CompileTimeEntryType EntryType = CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP; }; diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 91048d3374c..b0e4b751c8a 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -508,8 +508,12 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set); std::vector tvs_to_schedule{d, d_smem}; - if (std::find(mma_results_.begin(), mma_results_.end(), dc) == - mma_results_.end()) { + + bool dc_in_mma_results = + std::find(mma_results_.begin(), mma_results_.end(), dc) != + mma_results_.end(); + + if (!dc_in_mma_results) { // Skip scheduling dc if it is an mma_result. 
This can happen if we are // not casting back to half-precision in the output tvs_to_schedule.push_back(dc); @@ -519,34 +523,59 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setMemoryType(MemoryType::Local); d_smem->setMemoryType(MemoryType::Shared); - // Set LoadStoreOp - d_smem->definition()->as()->setOpType( - LoadStoreOpType::StMatrix); + auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2; + + if (store_with_stmatrix) { + // Set LoadStoreOp + d_smem->definition()->as()->setOpType( + LoadStoreOpType::StMatrix); + } d->definition()->as()->setOpType( LoadStoreOpType::CpAsyncBulkTensorTile); - // Block Schedule and Parallelize + // Apply the common transforms to dc, d_smem, d + // After these transforms we schedule the inner two non-reduction loops + // (instruction tile) of dc and propagate is back till the outputs of mma. blockTileTensors(tvs_to_schedule); parallelizeBlocks(tvs_to_schedule); - - // Apply mma common transformation for (auto tv : tvs_to_schedule) { transformLikeMmaOutput(tv, /*is_mma_result=*/false); } - // Schedule register cache; Output from epilogue - { + // Should not propagate if the dc is a mma output as the mma output has + // already been scheduled. + if (!dc_in_mma_results) { auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( dc->getLoopDomain()); dc->setLoopDomain(s.as()); dc->setAllocationDomain(s.as(), true); + + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + dc, + -1, + propagate_to, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); } MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem); - // Schedule shared memory cache; Output from StMatrix - mma_utils::scheduleStMatrixForMmaOutput( - d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n); + // [M, N] -> [128(TIDx), N/8 , m(2) , n(2)] + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + d_smem->getLoopDomain()); + if (swizzle != MmaInputSmemSwizzle::None) { + // Create tma store allocation domain with swizzle + mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle); + } + d_smem->setLoopDomain(s.as()); + + if (store_with_stmatrix) { + // Schedule shared memory cache; Output from StMatrix + mma_utils::scheduleStMatrixForMmaOutput( + d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n); + } + + d_smem->axis(-1)->parallelize(ParallelType::Vectorize); // Schedule global memory output; Output from TMA Store mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle); diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 7e8ee6dc4d7..c2b42dcf20b 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams { : "column-major") << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" + << "Cluster dimensions: " << std::get<0>(cluster_dims) << " " + << std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n" << "Use shared memory epilogue: " << use_smem_epilogue << "\n" << "Promote re-use of prologue shared memory: " << promote_prologue_smem_reuse << "\n" diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index 01333727841..ef0954f2185 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -141,6 +141,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile); 
setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile); setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro)); + config->cluster_dims[0] = std::get<0>(mparams->cluster_dims); + config->cluster_dims[1] = std::get<1>(mparams->cluster_dims); + config->cluster_dims[2] = std::get<2>(mparams->cluster_dims); config->splitk_factor = mparams->splitk_factor; config->grid_swizzle_factor = mparams->grid_swizzle_factor; config->cta_order = @@ -163,6 +166,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { }; setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile); setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile); + std::get<0>(mparams->cluster_dims) = config->cluster_dims[0]; + std::get<1>(mparams->cluster_dims) = config->cluster_dims[1]; + std::get<2>(mparams->cluster_dims) = config->cluster_dims[2]; mparams->circular_buffer_options.smem_circular_buffer_stage = config->load_stages; mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap = diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h index 207da96e9a8..1cd028b6a0a 100644 --- a/csrc/scheduler/matmul_heuristic_plugin_api.h +++ b/csrc/scheduler/matmul_heuristic_plugin_api.h @@ -72,6 +72,7 @@ struct KernelConfig { Tile cta_tile = {128, 128, 32}; Tile warp_tile = {64, 64, 32}; Tile instruction_tile = {16, 16, 16}; + Tile cluster_dims = {1, 1, 1}; uint16_t splitk_factor = 1; uint8_t load_stages = 2; // The circular buffering prefetch distance will be set to diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 415a28829c3..678ca85ba0f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -98,7 +98,8 @@ void limitCircularBufferingSmemOperands( inline bool initCoreHeuristics( MatmulParams* mparams, const ProblemShape& problem_shape, - const mma_utils::TensorRolesMap& tensor_roles) { + const mma_utils::TensorRolesMap& tensor_roles, + const size_t num_problems) { const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro); GemmTile warp_tile = {-1, -1, -1}; GemmTile cta_tile = {-1, -1, -1}; @@ -113,7 +114,7 @@ inline bool initCoreHeuristics( // - start with [4, 4, 2] shape, later it should depend on problem // shape and have bigger impact on CTA tile shape - const DimType m_ratio = 4; + const DimType m_ratio = 4 / (DimType)num_problems; const DimType n_ratio = 4; const DimType k_ratio = 2; @@ -264,10 +265,11 @@ std::string isMatmulFusionDefinitionSupported( {MatmulTensorRole::OPERAND_A, MatmulTensorRole::OPERAND_B}) { auto entry = tensor_roles.find(role); if (entry != tensor_roles.end()) { - if (1 == entry->second.size()) { + if (isOptionEnabled(EnableOption::FuseMultipleMatmuls) || + 1 == entry->second.size()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); } else { - return "There is other than one fusion input that can be MMA operand"; + return "There is more than one fusion input that can be MMA operand (enable fuse_multiple_matmuls)"; } } else { return "No candidate in fusion inputs for MMA operand"; @@ -370,10 +372,16 @@ class VectorizationCalculator { MatmulParams::SupportedVectorization compute() { const std::vector a_vecs = operandVectorizations(MatmulTensorRole::OPERAND_A); - NVF_ERROR(a_vecs.size() == 1, "Expected exactly one A operand"); + NVF_ERROR( + isOptionEnabled(EnableOption::FuseMultipleMatmuls) || + a_vecs.size() == 1, + "Expected exactly one A operand"); const std::vector b_vecs = 
operandVectorizations(MatmulTensorRole::OPERAND_B); - NVF_ERROR(b_vecs.size() == 1, "Expected exactly one B operand"); + NVF_ERROR( + isOptionEnabled(EnableOption::FuseMultipleMatmuls) || + b_vecs.size() == 1, + "Expected exactly one B operand"); return {a_vecs[0], b_vecs[0], epilogueVectorization()}; } @@ -703,8 +711,10 @@ std::unique_ptr getMatmulHeuristics( mma_utils::findMatmulPatterns(fusion); NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - patterns.size() == 1, - "Only a single matmul pattern can currently be fused"); + isOptionEnabled(EnableOption::FuseMultipleMatmuls) || + patterns.size() == 1, + "Only a single matmul pattern can currently be fused ", + "unless the fuse_multiple_matmuls option is enabled"); mma_utils::MatmulPattern& pattern = patterns.front(); // IdModel is used to analyze problem shape & layout @@ -750,14 +760,21 @@ std::unique_ptr getMatmulHeuristics( problem_shape[(size_t)MatmulDimRole::Batch], inner_dims, tensor_roles); + // TODO: more sophisticated handling of multiple matmuls when using plugin + mparams->tile_sizes.cta_tile.m /= (int64_t)patterns.size(); } else { TORCH_WARN_ONCE( "Scheduling a matmul without heuristic plugin. " "Specify plugin location like this: " "NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so"); // Populate heuristic details - auto status = - initCoreHeuristics(mparams.get(), problem_shape, tensor_roles); + auto status = initCoreHeuristics( + mparams.get(), + problem_shape, + tensor_roles, + // TODO: this assumes all patterns will lie in the same main loop, which + // might be false + /*num_problems=*/patterns.size()); NVF_ERROR(status, "Initialization of core part of heuristics failed."); } @@ -800,9 +817,10 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // scheduler. // 6. Check if the fusion is resharding. + const auto device_prop = at::cuda::getCurrentDeviceProperties(); + // #0 { - const auto device_prop = at::cuda::getCurrentDeviceProperties(); // Use a dummy problem shape to determine whether this is a supported // device. const auto mma_op = getMmaOp( @@ -824,6 +842,16 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { { for (const mma_utils::MatmulPattern& pattern : patterns) { Expr* op = pattern.output->definition(); + if (device_prop->major >= 9 && op->isA()) { + bool found_reduction = false; + for (size_t dim : c10::irange((size_t)pattern.output->nDims())) { + if (found_reduction && + !pattern.output->axis((int64_t)dim)->isReduction()) { + return "Mul+Sum patterns can only be translated to MmaOp " + "on Hopper if the reduction dim is innermost"; + } + } + } if (op->isA() || op->isA()) { if (!isOptionEnabled(EnableOption::FuseMatmul)) { // Check for MatmulOp or LinearOp. 
If found, then only fuse if option @@ -846,7 +874,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } } - if (patterns.size() > 1) { + if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls) && + patterns.size() > 1) { return "Only a single matmul pattern can currently be fused"; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index e9a24851a5a..3afc8d43a97 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -20,6 +20,7 @@ #include #include #include +#include "options.h" namespace nvfuser { @@ -187,7 +188,8 @@ TensorView* getOperandTv( NVF_ERROR(it != tensor_roles.end(), "Could not find any tensors with role"); const std::vector& operands = it->second; NVF_ERROR( - operands.size() == 1, + isOptionEnabled(EnableOption::FuseMultipleMatmuls) || + operands.size() == 1, "Exactly one operand is expected in each A and B role"); return operands.front(); } @@ -1309,19 +1311,9 @@ void scheduleStMatrixForMmaOutput( ((tile_m == 16 && tile_n == 16) || (tile_m == 16 && tile_n == 8)), "We only support 16x16 and 16x16 stmatrix now"); - NVF_ERROR( - tv->dtype() == DataType::Half, "we only support half type in stmatrix"); - - // [M, N] -> [128(TIDx), N/8 , 2 , 2] - auto s = - mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); - - if (swizzle != MmaInputSmemSwizzle::None) { - // Create tma store allocation domain with swizzle - mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle); - } - - tv->setLoopDomain(s.as()); + NVF_CHECK( + dataTypeSize(tv->dtype()) == 2, + "we only support 16-bit types in stmatrix"); if (tile_m == 16 && tile_n == 16) { // Let [M, N] be [64, 32] @@ -1341,16 +1333,18 @@ void scheduleStMatrixForMmaOutput( // [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)] tv->merge(-2); } - tv->axis(-1)->parallelize(ParallelType::Vectorize); } MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) { const std::vector patterns = findMatmulPatterns(fusion); if (patterns.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << patterns.size(); - return ss.str(); + if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls)) { + std::stringstream ss; + ss << "Invalid number of MmaOp instances in fusion, expected 1, got " + << patterns.size(); + return ss.str(); + } + TORCH_WARN("TODO: Update getOperandInnerDims for multiple patterns"); } const MatmulPattern& pattern = patterns[0]; IdModel id_model(fusion); @@ -1777,130 +1771,294 @@ std::string MatmulPattern::toString() const { return ss.str(); } -MmaOp* MatmulPattern::translateToMmaOp() { - if (auto mma_op = dynamic_cast(output->definition())) { - // No translation needed - return mma_op; - } else if (output->definition()->isA()) { - Val* init = IrBuilder::create(0.0, output->dtype()); +namespace { + +// The `MatmulTranslator` helper class is used to map different matrix +// multiplication patterns to `MmaOp`. The `MmaOp` expression maps to the +// TensorCore ptx function. +// +// 1. `MmaOp` --This expression is what we need, so no changes required. +// 2. `ReductionOp` -- This expression corresponds with the sum operation in +// the `broadcast->multiply->sum` pattern. +// 3. `LinearOp` -- This expression is `y = w @ x + beta`, so it is replaced +// with `MmaOp` and pointwise `add`. +// 4. `MatmulOp` -- This expression is `y = A[M, K] @ B[K, N]`. The `MmaOp` +// expression requires `[M, N, K]` ordering, so it requires transposing the +// `B` operand. 
It also support batch matrix multiplication, which is +// tracked by `MmaOp::AxisMapping`. +// +// `finalizeMatmulOrLinearOp` +// * Fused-Multiply-Sum (FMS) is the output from MmaOp. +// * The output dtype can be different than the original output dtype. +// * This function casts the FMS TensorView to the original output dtype if +// necessary. +// +// `OptInDispatch` is used to catch any unsupported `MatmulPattern`. It is +// preferred to throw an error than to fallback to a sub-optimal default +// `MmaOp` translation. +class MatmulTranslator : public OptInDispatch { + public: + static MmaOp* translate(MatmulPattern& pattern, bool avoid_intermediates) { + MatmulTranslator trans(pattern, avoid_intermediates); + trans.dispatch(pattern.output->definition()); + return trans.mma_; + } + + private: + MatmulTranslator(MatmulPattern& pattern, bool avoid_intermediates) + : pattern_(pattern), avoid_intermediates_(avoid_intermediates) {} + + using OptInDispatch::handle; + + void handle(MmaOp* mma) final { + mma_ = mma; + } + + void handle(ReductionOp* rop) final { + Val* init = IrBuilder::create(0.0, pattern_.output->dtype()); // This replaces the mul and sum by overwriting output->definition() - return IrBuilder::create( - output, - A, - B, + mma_ = IrBuilder::create( + pattern_.output, + pattern_.A, + pattern_.B, init, - MmaOp::AxisMapping::trivialMapping(output->nDims())); + MmaOp::AxisMapping::trivialMapping(pattern_.output->nDims())); } - // This will hold the translated output from MatmulOp or LinearOp - TensorView* fms = nullptr; - MmaOp* mma_op = nullptr; - if (auto lop = dynamic_cast(output->definition())) { + void handle(LinearOp* lop) final { + // This will hold the translated output from MatmulOp or LinearOp + TensorView* fms = nullptr; // Linear takes inputs input, weight(, bias) - // - input can be any dimension > 0. We assert that it must be at least 2 - // and refuse to translate if dimension is 1. + // - input can be any dimension > 0. We assert that it must be at least + // 2 and refuse to translate if dimension is 1. // - weight can be one or two dimensional. We refuse to translate if // dimension is 1. // - bias, if present, can be zero or one dimensional. Bias can only be // present if weight is 2D // + // When A has dimension greater than two, all the preceding dimensions + // are essentially also M dimensions. The output is shaped like + // + // A [ ... iS0{M} iS1{K} ] + // B [ iS2{N} iS3{K} ] + // out [ ... iS3{M} iS3{N} rS3{K} ] + // // We translate by broadcasting input, weight, and bias such that the // contracted dimension K is in the last position (this is true of the // logical domains in input and weight already). Then we form an MmaOp and // optionally add the bias tensor followed by a cast back to the input // dtype. 
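// A shapes-only sketch of the broadcast-based path (illustrative, 2D input,
// names as in the surrounding code):
//
//   A [ M, K ]  --broadcast-->  [ M, bN, K ]
//   B [ N, K ]  --broadcast-->  [ bM, N, K ]
//   fms = fusedMultiplySum(A, B, {-1});   // [ M, N, rK ]
//   fms = add(fms, bias);                 // only if bias is present
//   // cast back to the original output dtype if needed (see
//   // finalizeMatmulOpOrLinearOp below)
//
// The avoid_intermediates_ branch builds the same MmaOp without these
// broadcasts by recording the correspondence in an MmaOp::AxisMapping.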
NVF_ERROR( - A->nDims() > 1 && B->nDims() > 1, + pattern_.A->nDims() > 1 && pattern_.B->nDims() > 1, "Cannot translate LinearOp with 1D input"); - std::vector bcast_dim((size_t)A->nDims() + 1, false); - bcast_dim[bcast_dim.size() - 2] = true; // N - A = broadcast(A, bcast_dim); + NVF_ERROR( + pattern_.B->nDims() == 2, + "Cannot translate LinearOp without 2D weight tensor"); + if (avoid_intermediates_) { + MmaOp::AxisMapping axis_mapping; + int64_t out_dim = pattern_.A->nDims() + 1L; + axis_mapping.a_axes.reserve(out_dim); + for (int64_t d : c10::irange(out_dim - 2L)) { + axis_mapping.a_axes.push_back(d); + } + axis_mapping.a_axes.reserve(out_dim); + for (size_t d : c10::irange(out_dim - 2)) { + axis_mapping.a_axes.push_back((int64_t)d); + } + axis_mapping.a_axes.push_back(-1); // missing N dimension + axis_mapping.a_axes.push_back(pattern_.A->nDims() - 1); // K dimension + + axis_mapping.b_axes.reserve(out_dim); + axis_mapping.b_axes.resize(out_dim, -1); + axis_mapping.b_axes[out_dim - 2] = 0; // N + axis_mapping.b_axes[out_dim - 1] = 1; // K + + int64_t num_M_dims = 1 + pattern_.A->nDims() - pattern_.B->nDims(); - bcast_dim[bcast_dim.size() - 2] = false; // reset N - std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true); - B = broadcast(B, bcast_dim); + // Add loop broadcasts to A and B to mimic logical broadcasts for + // simpler scheduling + pattern_.A->broadcast(-2); // There's always a single N dimension - fms = fusedMultiplySum(A, B, {-1}); - mma_op = fms->definition()->as(); + for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) { + // Broadcast B for every M dimension in A + pattern_.B->broadcast(0); + } + + fms = fusedMultiplySum( + pattern_.A, pattern_.B, {-1}, /*init=*/nullptr, axis_mapping); + } else { + std::vector bcast_dim(pattern_.A->nDims() + 1, false); + bcast_dim[bcast_dim.size() - 2] = true; // N + pattern_.A = broadcast(pattern_.A, bcast_dim); + + bcast_dim[bcast_dim.size() - 2] = false; // reset N + std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true); + pattern_.B = broadcast(pattern_.B, bcast_dim); + + fms = fusedMultiplySum(pattern_.A, pattern_.B, {-1}); + } + + mma_ = fms->definition()->as(); auto* bias = dynamic_cast(lop->bias()); if (bias != nullptr) { fms = add(fms, bias); } - } else if (output->definition()->isA()) { - // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we - // must transpose B then broadcast both operands before creating the final - // op. + finalizeMatmulOpOrLinearOp(fms); + } + + void handle(MatmulOp* mop) final { + // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so + // we must transpose B then broadcast both operands before creating the + // final op. // // Also note that the output of MatmulOp is a tensor of shape [..., M, N] // whose dtype matches that of the inputs. We will most commonly then also // need to cast the output of the MmaOp to produce the output TensorView. + // + // There are two possibilities: + // + // Case 1: A->nDims() > B->nDims(): + // + // A [ ..., B1, ..., Bn, M, K ] + // B [ B1, ..., Bn, K, N ] + // + // All the preceding dimensions in A are additional M dimensions. There + // are batch dimensions in between those and "M". + // + // Case 2: A->nDims() <= B->nDims(): + // + // A [ B1, ..., Bn, M, K ] + // B [ ..., B1, ..., Bn, K, N ] + // + // All the preceding dimensions in B are additional N dimensions. There + // are batch dimensions in between those and "N". 
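// For example (shapes only):
//   Case 1: A [ i0, B1, M, K ], B [ B1, K, N ]  ->  out [ i0, B1, M, N ]
//   Case 2: A [ B1, M, K ], B [ i0, B1, K, N ]  ->  out [ i0, B1, M, N ]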
+ // + // In either case, to form the output we transpose B in the last two dims, + // and prepend broadcasts to the lower dimensional input as needed. NVF_ERROR( - A->nDims() > 1 && B->nDims() > 1, + pattern_.A->nDims() > 1 && pattern_.B->nDims() > 1, "Cannot translate MatmulOp with 1D input"); - TensorView* Btrans = transpose(B, -2, -1); - A = unsqueeze(A, -2); - B = unsqueeze(Btrans, -3); - // A and B might have different dimensions. If so, broadcast the smaller one - // up to the size of the larger. - int64_t out_dims = std::max(A->nDims(), B->nDims()); - // Add new outer broadcast dimensions if necessary - A = ops::maybe_broadcast_inner_to_rank(A, out_dims); - B = ops::maybe_broadcast_inner_to_rank(B, out_dims); - fms = fusedMultiplySum(A, B, {-1}); - mma_op = fms->definition()->as(); - } else { - NVF_THROW( - "Could not translate matmul pattern with output ", - output->toString(), - " to MmaOp"); - } - NVF_ERROR(fms != nullptr); - NVF_ERROR(mma_op != nullptr); + TensorView* fms = nullptr; + if (avoid_intermediates_) { + MmaOp::AxisMapping axis_mapping; + int64_t out_dims = std::max(pattern_.A->nDims(), pattern_.B->nDims()) + 1; + + axis_mapping.a_axes.resize((size_t)out_dims, -1); + axis_mapping.b_axes.resize((size_t)out_dims, -1); + + for (size_t a_axis : c10::irange((size_t)pattern_.A->nDims() - 1)) { + // Output is [ ... M, N, K ] + // This loop maps everything but N and K to A + int64_t out_axis = + (int64_t)a_axis + (out_dims - 1 - pattern_.A->nDims()); + axis_mapping.a_axes.at((size_t)out_axis) = (int64_t)a_axis; + } + // Map the K dim, skipping one position + axis_mapping.a_axes.at((size_t)out_dims - 1) = pattern_.A->nDims() - 1; + + for (size_t b_axis : c10::irange((size_t)pattern_.B->nDims() - 2)) { + // Output is [ ... M, N, K ] + // This loop maps everything before M to B, skipping the output M dim + int64_t out_axis = + (int64_t)b_axis + (out_dims - pattern_.B->nDims()) - 1; + axis_mapping.b_axes.at((size_t)out_axis) = (int64_t)b_axis; + } + // Skip the K dim and map N and K + axis_mapping.b_axes.at((size_t)out_dims - 2) = pattern_.B->nDims() - 1; + axis_mapping.b_axes.at((size_t)out_dims - 1) = pattern_.B->nDims() - 2; - // The following is common to both MatmulOp and LinearOp translation + fms = fusedMultiplySum( + pattern_.A, pattern_.B, {-1}, /*init=*/nullptr, axis_mapping); + + int64_t num_M_dims = + std::max(1 + pattern_.A->nDims() - pattern_.B->nDims(), (int64_t)1); + + // Reorder to BMNK. + // Add loop broadcasts to A and B to mimick logical broadcasts for + // simpler scheduling + pattern_.A->broadcast(-2); - // TODO: skip downcasting if the only uses of `output` are casts back to - // higher precision in order avoid the round trip cast in defining an - // epilogue that starts with MatmulOp. - if (output->dtype() != fms->dtype()) { - // When fms is a different dtype from output, it means we _might_ need to - // insert a cast. However, we can skip inserting that cast for any uses of - // output that are simply casts back to Float. - - // This vector holds tensors that would be round-trip cast to the same - // dtype as fms. We first collect these Vals then we do the replacements - // separately in order to avoid dereferencing an Expr* that has already - // been replaced. 
- std::vector round_trip_vals; - for (Expr* use : output->uses()) { - if (auto* uop = dynamic_cast(use); uop != nullptr && - uop->getUnaryOpType() == UnaryOpType::Cast && - uop->out()->dtype() == fms->dtype()) { - round_trip_vals.push_back(uop->out()); + pattern_.B->reorder({{-2, -1}}); + for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) { + // Broadcast B for every M dimension in A + pattern_.B->broadcast(-3); } + } else { + TensorView* Btrans = transpose(pattern_.B, -2, -1); + pattern_.A = unsqueeze(pattern_.A, -2); + pattern_.B = unsqueeze(Btrans, -3); + // A and B might have different dimensions. If so, broadcast the smaller + // one up to the size of the larger. + int64_t out_dims = std::max(pattern_.A->nDims(), pattern_.B->nDims()); + // Add new outer broadcast dimensions if necessary + pattern_.A = ops::maybe_broadcast_inner_to_rank(pattern_.A, out_dims); + pattern_.B = ops::maybe_broadcast_inner_to_rank(pattern_.B, out_dims); + fms = fusedMultiplySum(pattern_.A, pattern_.B, {-1}); } - // If there are any uses that were not round-trip casts, then we should - // insert the castOp. - if (output->uses().size() > round_trip_vals.size()) { - TensorView* old_output = output; - output = castOp(output->dtype(), fms); - ir_utils::replaceValInAllExprInputsAndFusionOutputs(old_output, output); - } - // if any casts are skipped, then we reset output to point to the Float - // output fms instead of the downcast. - if (!round_trip_vals.empty()) { - output = fms; - } - // Finally, replace the round_trip_vals with fms - for (Val* v : round_trip_vals) { - ir_utils::replaceValInAllExprInputsAndFusionOutputs(v, fms); + mma_ = fms->definition()->as(); + finalizeMatmulOpOrLinearOp(fms); + } + + // The following is common to both MatmulOp and LinearOp translation + void finalizeMatmulOpOrLinearOp(TensorView* fms) { + NVF_ERROR(fms != nullptr); + NVF_ERROR(mma_ != nullptr); + + // TODO: skip downcasting if the only uses of `output` are casts back to + // higher precision in order avoid the round trip cast in defining an + // epilogue that starts with MatmulOp. + if (pattern_.output->dtype() != fms->dtype()) { + // When fms is a different dtype from output, it means we _might_ need + // to insert a cast. However, we can skip inserting that cast for any + // uses of output that are simply casts back to Float. + + // This vector holds tensors that would be round-trip cast to the same + // dtype as fms. We first collect these Vals then we do the replacements + // separately in order to avoid dereferencing an Expr* that has already + // been replaced. + std::vector round_trip_vals; + for (Expr* use : pattern_.output->uses()) { + if (auto* uop = dynamic_cast(use); uop != nullptr && + uop->getUnaryOpType() == UnaryOpType::Cast && + uop->out()->dtype() == fms->dtype()) { + round_trip_vals.push_back(uop->out()); + } + } + // If there are any uses that were not round-trip casts, then we should + // insert the castOp. + if (pattern_.output->uses().size() > round_trip_vals.size()) { + TensorView* old_output = pattern_.output; + pattern_.output = castOp(pattern_.output->dtype(), fms); + ir_utils::replaceValInAllExprInputsAndFusionOutputs( + old_output, pattern_.output); + } + // if any casts are skipped, then we reset output to point to the Float + // output fms instead of the downcast. 
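// Illustrative example of the round trip being removed: if the original
// fusion contained
//   out_h = matmul(a, b);                  // reduced precision, e.g. Half
//   x = castOp(DataType::Float, out_h);
// then x is replaced by fms (which is already Float), so the
// Float -> Half -> Float round trip never reaches the generated kernel.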
+ if (!round_trip_vals.empty()) { + pattern_.output = fms; + } + // Finally, replace the round_trip_vals with fms + for (Val* v : round_trip_vals) { + ir_utils::replaceValInAllExprInputsAndFusionOutputs(v, fms); + } + } else { + // No cast needed, for example the inputs might be Float + ir_utils::transferDefinitionToNewOutputs( + fms->definition(), {pattern_.output}); } - } else { - // No cast needed, for example the inputs might be Float - ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); } - return mma_op; + + private: + MatmulPattern& pattern_; + bool avoid_intermediates_; + MmaOp* mma_ = nullptr; +}; + +} // namespace + +MmaOp* MatmulPattern::translateToMmaOp(bool avoid_intermediates) { + return MatmulTranslator::translate(*this, avoid_intermediates); } namespace { @@ -1986,17 +2144,18 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const { // for each valgroup, store a pair of flags. The first records whether the // group is present at all in the tv. The second records whether the value is // concrete (i.e. not reduction, broadcast, or device). - std::unordered_map> flags; + std::unordered_map flags; const auto recordPresence = [&graph, &flags]( TensorView* tv, size_t tensor_num) { for (IterDomain* id : tv->getLogicalDomain()) { const ValGroup& g = graph.toGroup(id); - auto& [present_flags, concrete_flags] = flags[g]; - present_flags.set(tensor_num); + DimPresence& group_flags = flags[g]; + // Note: broadcast or device dims will be initialized to have all false + // flags above if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) { continue; } - concrete_flags.set(tensor_num); + group_flags.set(tensor_num); } }; recordPresence(A, 0); @@ -2005,8 +2164,7 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const { DimRolesMap dim_roles; - for (const auto& [g, f] : flags) { - const auto& [present_flags, concrete_flags] = f; + for (const auto& [g, concrete_flags] : flags) { if (concrete_flags.all() || concrete_flags.none()) { // Batch dimensions are any of those that are not concretized or reduced. // These could be all Iteration or all Broadcast @@ -2019,9 +2177,25 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const { dim_roles[g] = MatmulDimRole::N; } else { NVF_THROW( - "IterDomain ValGroup should be present in at least two of A, B, output.", - " present_flags: ", - present_flags); + "IterDomain ValGroup should be concrete in at least two of A, B, output.", + " concrete_flags: ", + concrete_flags); + } + } + + // NOTE: For Hopper, we create loop broadcasts to mimic logical broadcasts + // when translating MatmulOp and LinearOp. Here we detect these and map them + // appropriately. + for (IterDomain* id : A->getLoopDomain()) { + const ValGroup& g = graph.toGroup(id); + if (dim_roles.count(g) == 0) { + dim_roles[g] = MatmulDimRole::N; + } + } + for (IterDomain* id : B->getLoopDomain()) { + const ValGroup& g = graph.toGroup(id); + if (dim_roles.count(g) == 0) { + dim_roles[g] = MatmulDimRole::M; } } diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index c88fe4926e3..6bef370240e 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -327,7 +327,11 @@ struct MatmulPattern { //! there is a MatmulOp instead, this function modifies the fusion to insert //! an MmaOp. TensorViews A and B are unchanged, but this->output might be //! updated to reflect the replacement tensor. - MmaOp* translateToMmaOp(); + //! + //! If avoid_intermediates is true, this function will use an + //! 
MmaOp::AxisMapping instead of broadcasting and permuting axes, in order to + //! avoid introducing unnecessary copies on Hopper and above. + MmaOp* translateToMmaOp(bool avoid_intermediates = false); //! Given an IdModel, map groups of IterDomains to dimension roles //! (MatmulDimRole). Note that ValGroup is a shared_ptr to a diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp index 915e08ab8e8..33b350fd467 100644 --- a/csrc/scheduler/multi_matmul.cpp +++ b/csrc/scheduler/multi_matmul.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include #include @@ -21,7 +22,23 @@ void MultipleMatmulScheduler::findPatterns() { void MultipleMatmulScheduler::translatePatterns() { mma_results_.reserve(patterns_.size()); for (mma_utils::MatmulPattern& pattern : patterns_) { - MmaOp* mma = pattern.translateToMmaOp(); + // TODO: properly handle all mul+sum patterns for Hopper. For now, these + // should work fine as long as the inner dimensions are the ones being + // reduced. + if (!isAmpere(params_->mma_macro) && !isTuring(params_->mma_macro) && + pattern.output->definition()->isA()) { + bool found_reduction = false; + for (size_t dim : c10::irange((size_t)pattern.output->nDims())) { + NVF_ERROR( + !found_reduction || + !pattern.output->axis((int64_t)dim)->isReduction(), + "Mul+Sum patterns can only be translated on Hopper if the reduction dim is innermost"); + } + } + + MmaOp* mma = pattern.translateToMmaOp( + /*avoid_intermediates=*/!isAmpere(params_->mma_macro) && + !isTuring(params_->mma_macro)); mma_results_.push_back(mma->out()->as()); } diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp index d75a171d00c..a7eb6e2de1f 100644 --- a/csrc/scheduler/no_op.cpp +++ b/csrc/scheduler/no_op.cpp @@ -7,8 +7,8 @@ // clang-format on #include +#include #include -#include #include #include #include @@ -48,7 +48,7 @@ bool NoOpScheduler::canScheduleCompileTime(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); if (exprs.size() == 1 && isResharding(exprs[0]) && - isLowerableToCommunication(exprs[0])) { + HostIrLower::canLower(exprs[0])) { return true; } diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index bc7a0fb32c6..277f6890570 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -29,37 +29,6 @@ namespace { // Unused at the moment, commenting for clang tidy constexpr int64_t kThreadX = 128; -class DomainMap : public pointwise_utils::DomainMap { - public: - using pointwise_utils::DomainMap::DomainMap; - - // The pointwise scheduler heuristics requires a minimum number of axes. - // The output reference tensor should respect this requirement. 
- TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const { - TensorView* result = nullptr; - int64_t max_dims = -1; - for (auto output_tv : - ir_utils::filterByType(fusion_->outputs())) { - if (isValidReference(output_tv) && - hasMinimumSize(output_tv, minimum_num_axes) && - !output_tv->isFusionInput()) { - int64_t n_dims = pointwise_utils::nRootDims(output_tv); - if (n_dims > max_dims) { - result = output_tv; - max_dims = n_dims; - } - } - } - return result; - } - - private: - bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { - NVF_ERROR(tv != nullptr); - return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); - } -}; - } // namespace std::unique_ptr getPointwiseHeuristics( @@ -79,14 +48,17 @@ std::unique_ptr getPointwiseHeuristics( auto domain_map_entry = HeuristicDataCacheEntry( - data_cache, - [fusion]() { return std::make_unique(fusion); }); - const auto& domain_map = dynamic_cast(domain_map_entry.get()); + data_cache, [fusion]() { + return std::make_unique( + fusion); + }); + const auto& domain_map = dynamic_cast( + domain_map_entry.get()); auto largest_out_entry = HeuristicDataCacheEntry( data_cache, [&domain_map]() { - std::vector data{domain_map.findReferenceTensorView()}; + std::vector data{domain_map.findReferenceTensor()}; return std::make_unique>(std::move(data)); }); TensorView* largest_out = largest_out_entry.get()[0]; @@ -432,19 +404,11 @@ std::unique_ptr getPointwiseHeuristics( return params; } -// Return reference tensor view. -TensorView* getReferenceTensorView(Fusion* fusion) { - FusionGuard fg(fusion); - DomainMap domain_map(fusion); - auto reference_tv = domain_map.findReferenceTensorView(); - return reference_tv; -} - //! Utility for canSchedule interface to check if this fusion has //! a fully broadcasted reference tensor, which is necessary for //! the pointwise scheduler. bool hasReferenceTensorView(Fusion* fusion) { - return getReferenceTensorView(fusion) != nullptr; + return pointwise_utils::getReferenceTensor(fusion) != nullptr; } bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { @@ -529,11 +493,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(scheduler_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(scheduler_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. 
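// As an illustration, an input such as T[ iS0{i0}, bS1{1}, rS2{i2} ] is
// expected to contribute only 1 to max_dims here, assuming nLogicalDims,
// like the nRootDims helper it replaces, skips broadcast, reduction and
// device dimensions.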
@@ -541,7 +505,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { return; } - TensorView* reference_tv = getReferenceTensorView(fusion); + TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion); NVF_ERROR( reference_tv != nullptr, diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 2f4f119fc46..054badfc36a 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -5,247 +5,16 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include -#include #include -#include - -#include namespace nvfuser { namespace pointwise_utils { -namespace { - -// Grab all exact set mappings from consumer to producer domains of -// indexed accesses, e.g., index_select -std::unordered_multimap< - std::shared_ptr>, - std::shared_ptr>> -getIndexedConsumerToProducerMap(Fusion* fusion, const ComputeAtMap& ca_map) { - std::unordered_multimap< - std::shared_ptr>, - std::shared_ptr>> - indexed_id_map; - - for (auto expr : fusion->exprs()) { - if (auto gather = dynamic_cast(expr)) { - auto p_id = gather->getIndexedID(); - auto c_id = gather->getConsumerOfIndexedID(); - indexed_id_map.emplace( - ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), - ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); - } else if (auto index_select = dynamic_cast(expr)) { - auto p_id = index_select->getIndexedID(); - auto c_id = index_select->getConsumerOfIndexedID(); - indexed_id_map.emplace( - ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), - ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); - } else { - // Note there's no consumer ID for select. This means we can't - // just propagate from consumers to indexed producers. It seems - // it's necessary to schedule producers and consumers separately - // in those cases. - continue; - } - } - - return indexed_id_map; -} - -// Check if a root ID of a fusion input tensor that is indirectly -// accessed by ops such as torchGather needs to be mapped with -// a reference tensor. Select has a similar effect as squeeze as the -// indexed domain is removed, so the domain does not need to be mapped -// as long as the tensor is a fusion input. Similarly, in index_select -// and torchGather, if the output domain is a broadcast, it does not -// need to be mapped if not resolved. -bool canIgnoreIndexedInputDomainID( - TensorView* input_tv, - IterDomain* root_id, - const ComputeAtMap& ca_map) { - NVF_ERROR(input_tv->isFusionInput()); - for (auto use : input_tv->uses()) { - if (auto select = dynamic_cast(use)) { - if (root_id != select->getIndexedID()) { - return false; - } - } else if (auto index_select = dynamic_cast(use)) { - // If the root_id is an indexed ID, and the consumer ID may be a - // broadcast. In that case, nothing needs to be mapped if the - // consumer broadcast is not resolved - if (root_id != index_select->getIndexedID() || - !ca_map - .getConcreteMappedID( - index_select->getConsumerOfIndexedID(), - IdMappingMode::PERMISSIVE) - ->isBroadcast()) { - return false; - } - } else if (auto gather = dynamic_cast(use)) { - // TODO: Remove this. Once slice is used for torchGather, this - // should not be necessary. For now, it is necessary to not - // break the existing torchGather tests - if (!gather->exactSizes()) { - continue; - } - // If the root_id is an indexed ID, and the consumer ID may be a - // broadcast. 
In that case, nothing needs to be mapped if the - // consumer broadcast is not resolved - if (root_id != gather->getIndexedID() || - !ca_map - .getConcreteMappedID( - gather->getConsumerOfIndexedID(), IdMappingMode::PERMISSIVE) - ->isBroadcast()) { - return false; - } - } else { - // If the input TV is used by any other ops - return false; - } - } - - return true; -} - -} // namespace - -DomainMap::DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { - tvs_with_rfactor_ = scheduler_utils::getTVsWithNonReductionRFactor(fusion); -} - -// Determine if all IterDomains in input are mapped to the given tensor -bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) - const { - // Get concrete IDs for input root or logical domain - std::unordered_set in_concrete_ids; - for (auto in_id : input_tv->getLogicalDomain()) { - if (canIgnoreIndexedInputDomainID(input_tv, in_id, ca_map_)) { - continue; - } - - // Permissive map is required for the transpose scheduler to support cases - // like T0[I0, b] + T1[b, I1] - auto concrete = - ca_map_.getConcreteMappedID(in_id, IdMappingMode::PERMISSIVE); - - if (!concrete->isBroadcast() && !in_id->isReduction()) { - in_concrete_ids.insert(concrete); - } - } - - // Erase all input concrete IDs mapped to the output domain - // Ignore unresolved broadcast dimensions - eraseifInputMappedThroughRootDomainAndIndexing( - in_concrete_ids, tv->getLogicalDomain()); - - return in_concrete_ids.empty(); -} - -// Reference domains must exactly match with the input domains. See -// also PR #661 -IterDomain* DomainMap::getMappedInputConcreteID( - const std::unordered_set& in_concrete_ids, - IterDomain* out_id) const { - auto in_concrete_id_iter = std::find_if( - in_concrete_ids.begin(), - in_concrete_ids.end(), - [&](IterDomain* in_concrete_id) { - return ca_map_.areMapped(in_concrete_id, out_id, IdMappingMode::EXACT); - }); - if (in_concrete_id_iter != in_concrete_ids.end()) { - return *in_concrete_id_iter; - } else { - return nullptr; - } -} - -// Erase input concrete ID if it is mapped to output ID -bool DomainMap::eraseIfMapped( - std::unordered_set& in_concrete_ids, - IterDomain* out_id) const { - auto mapped_input_conrete_id = - getMappedInputConcreteID(in_concrete_ids, out_id); - if (mapped_input_conrete_id != nullptr) { - in_concrete_ids.erase(mapped_input_conrete_id); - return true; - } else { - return false; - } -} - -void DomainMap::eraseifInputMappedThroughRootDomainAndIndexing( - std::unordered_set& in_ids, - const std::vector& ids) const { - // Use ComputeAtMap::getAllDisjointSetProducers to grab all producer - // IDs through rfactor exprs - VectorOfUniqueEntries>> - exact_sets; - std::for_each(ids.begin(), ids.end(), [&](IterDomain* id) { - exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - - // Traverse through indexed domains. 
- const auto indexed_id_multimap = - getIndexedConsumerToProducerMap(fusion_, ca_map_); - - VectorOfUniqueEntries>> - all_exact_sets_covered; - - // Back traverses through the exact map and indexed - // producer-consumer pairs - for (auto current_sets = exact_sets; !current_sets.empty();) { - auto producer_sets = ca_map_.getAllDisjointSetProducers(current_sets); - all_exact_sets_covered.pushBack(producer_sets); - - current_sets.clear(); - - // Further traversal if any of the new producer sets is a producer - // of indexed domains - for (const auto& producer_set : producer_sets) { - auto indexed_id_multimap_range = - indexed_id_multimap.equal_range(producer_set); - for (auto producer_of_producer_it = indexed_id_multimap_range.first; - producer_of_producer_it != indexed_id_multimap_range.second; - ++producer_of_producer_it) { - current_sets.pushBack(producer_of_producer_it->second); - } - } - } - - for (const auto& exact_set_ptr : all_exact_sets_covered) { - auto exact_concrete_id = ca_map_.getConcreteMappedID( - exact_set_ptr->front(), IdMappingMode::EXACT); - eraseIfMapped(in_ids, exact_concrete_id); - } -} - -// Find any id in domain that maps with target id -IterDomain* DomainMap::anyMapped( - const std::vector& domain, - IterDomain* target) const { - for (auto id : domain) { - if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) { - return id; - } - } - return nullptr; -} - -// Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input. -bool DomainMap::isValidReference(TensorView* tv) const { - for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { - if (input_tv->uses().empty()) { - continue; - } - // TODO: Same backward traversal from tv is done for all input - // tvs. Consider doing the analysis one for all inputs - if (!areAllInputIdsMappedTo(input_tv, tv)) { - return false; - } - } - return true; +TensorView* getReferenceTensor(Fusion* fusion) { + FusionGuard fg(fusion); + scheduler_tools::PointwiseDomainMap domain_map(fusion); + auto reference_tv = domain_map.findReferenceTensor(); + return reference_tv; } } // namespace pointwise_utils diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 56db0ee0806..f9263ed343a 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -11,68 +11,14 @@ #include #include #include +#include #include namespace nvfuser { namespace pointwise_utils { -// DomainMap uses the ComputeAtMap to find a reference TensorView -// that maps to all IterDomains in the fusion. -class DomainMap { - public: - DomainMap(Fusion* fusion); - virtual ~DomainMap() = default; - - const ComputeAtMap& getComputeAtMap() const { - return ca_map_; - } - - // Determine if a TensorView is a valid reference tensor for this fusion. - // The reference tensor must map to all the iterDomains in each input. 
- bool isValidReference(TensorView* tv) const; - - protected: - // Determine if all IterDomains are mapped between input and the given tvs - bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) - const; - - virtual IterDomain* getMappedInputConcreteID( - const std::unordered_set& in_concrete_ids, - IterDomain* out_id) const; - - // Erase input concrete ID if it is mapped to output ID - bool eraseIfMapped( - std::unordered_set& in_concrete_ids, - IterDomain* out_id) const; - - // Check if in_ids are mapped to ids through any root domain as - // well as indirectly accessed domains with ops like torchGather - void eraseifInputMappedThroughRootDomainAndIndexing( - std::unordered_set& in_ids, - const std::vector& ids) const; - - // Find any id in domain that maps with target id - IterDomain* anyMapped( - const std::vector& domain, - IterDomain* target) const; - - Fusion* fusion_ = nullptr; - ComputeAtMap ca_map_; - std::vector tvs_with_rfactor_; -}; - -// Returns number of non-reduction/non-broadcas/non-device dims in logical -// domain -inline int64_t nRootDims(const TensorView* tv) { - auto logical_dom = tv->getLogicalDomain(); - int64_t tv_n_dims = 0; - for (auto dim : logical_dom) { - if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) { - tv_n_dims++; - } - } - return tv_n_dims; -} +// Return reference tensor view. +TensorView* getReferenceTensor(Fusion* fusion); } // namespace pointwise_utils } // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index fd9573d6b32..039d94b38af 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -90,6 +91,8 @@ std::unique_ptr SchedulerEntry::makeSchedulerInstance( return std::make_unique(); case SchedulerType::ExprEval: return std::make_unique(); + case SchedulerType::Resize: + return std::make_unique(); default: NVF_THROW("unreachable"); } diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp new file mode 100644 index 00000000000..41277250b16 --- /dev/null +++ b/csrc/scheduler/resize.cpp @@ -0,0 +1,255 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +namespace { + +// Just use the pointwise version for now +TensorView* getReferenceTensor(Fusion* fusion) { + return pointwise_utils::getReferenceTensor(fusion); +} + +} // namespace + +bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { + if (!isOptionEnabled(EnableOption::ResizeScheduler)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Not enabled"); + return false; + } + + if (!ir_utils::hasOpsOfType(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "No resize op to schedule"); + return false; + } + + if (scheduler_utils::isResharding(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Fusion is resharding."); + return false; + } + + if (ir_utils::hasAnyReductionOps(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "No support for reduction ops"); + return false; + } + + if (registry_utils::hasNonUniqueBcast(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Broadcasting dimension might be broadcasting to multiple sizes."); + return false; + } + + // For now, the resize scheduler is only allowed for a limited set + // of fusion patterns. The restrictions are planned to be + // incrementally relaxed. + + IdModel id_model(fusion, /*build_graphs=*/false); + const auto& broadcast_graph = id_model.buildBroadcastGraph(); + + auto resize_tensor_ops = ir_utils::getOpsOfType(fusion); + + // Slicing of or to a broadcast ID is not allowed yet. + for (auto resize_tensor_op : resize_tensor_ops) { + TensorView* out_tv = resize_tensor_op->output(0)->as(); + for (auto logical_id : out_tv->getLogicalDomain()) { + Resize* resize = dynamic_cast(logical_id->definition()); + if (resize == nullptr) { + continue; + } + + if (resize->out()->isBroadcast()) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Resize to a broadcast ID is not allowed."); + return false; + } + + // Need to check the broadcast group rather than just the input + // ID only. For example, + // + // t0: [i0] + // t1: [b1] + // t2 = t0 + t1 + // t3 = slice(t2) + // + // Then, propagating the slice to its inputs would try to + // propagate the resize op to b1 as well, which would fail due + // to issue #3571 + const auto& input_group = broadcast_graph.toGroup(resize->in()); + if (std::any_of( + input_group->begin(), input_group->end(), [](Val* inp_val) { + return inp_val->as()->isBroadcast(); + })) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Resize of a broadcast ID is not allowed."); + return false; + } + } + } + + // This doesn't work yet due to issue #3571 + auto ref_tv = getReferenceTensor(fusion); + if (std::any_of( + ref_tv->getLogicalDomain().begin(), + ref_tv->getLogicalDomain().end(), + [](IterDomain* logical_id) { return logical_id->isBroadcast(); })) { + return false; + } + + // Having different resizes between outputs is not allowed at this + // moment. 
For example, consider a fusion like: + // + // t0 = [i0] + // fusion.addInput(t0) + // t1 = t0[:i0/2] + // t2 = t0[i0/2:] + // fusion.addOutput(t1) + // fusion.addOutput(t2) + // + // For now, this is not going to be fused since t1 and t2 have + // different resize ops, although in this case, since the extents of t1 and + // t2 are the same, it should be relatively straightforward to fuse them + // together. + for (auto out_tv : ir_utils::filterByType(fusion->outputs())) { + if (out_tv == ref_tv) { + continue; + } + auto exprs = ValGraphBFS::getExprGroupsBetween( + broadcast_graph, + broadcast_graph.toGroups(ref_tv->getLogicalDomain()), + broadcast_graph.toGroups(out_tv->getLogicalDomain()), + /*require_all_to_visited=*/false) + .first; + for (const auto& [expr_g, dir] : exprs) { + if (expr_g->front()->isA()) { + std::stringstream msg; + msg << "Resize between reference and output not allowed."; + msg << " Reference: " << ref_tv->toString() + << ". Output: " << out_tv->toString() + << ". Resize: " << expr_g->front()->toString(); + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), msg.str()); + return false; + } + } + } + + // Disable the scheduler if there's a squeeze op. The loop option + // may also need to be enabled in that case, but that option is not + // turned on automatically yet. + if (ir_utils::hasOpsOfType(fusion)) { + return false; + } + + return true; +} + +std::unique_ptr ResizeScheduler::computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache) { + FUSER_PERF_SCOPE("ResizeScheduler::computeHeuristics"); + auto params = std::make_unique(SchedulerType::Resize); + params->cparams.index_type = runtime_info.getIndexType(); + return params; +} + +void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { + FUSER_PERF_SCOPE("ResizeScheduler::schedule"); + + FusionGuard fg(fusion); + + scheduler_utils::clearMemorySpace(fusion); + + scheduler_utils::cacheInputs(fusion, true); + scheduler_utils::cacheAndForkOutputs(fusion, true); + + auto resize_tensor_ops = ir_utils::getOpsOfType(fusion); + + IdModel id_model(fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + + // Replicate resize inputs if necessary to avoid conflicting + // propagations + const auto exclusivity_info_map = scheduler_tools::getNonExclusiveResizeInfo( + resize_tensor_ops, exact_graph); + for (auto resize_tensor_op : resize_tensor_ops) { + auto out_tv = resize_tensor_op->output(0)->as(); + if (exclusivity_info_map.count(out_tv) == 0) { + continue; + } + auto inp_tv = resize_tensor_op->input(0)->as(); + // Since cacheInput may skip caching if an input is used by + // slice/pad, inp_tv may be a fusion input, in which case it is + // not necessary to recompute the tensor. + if (inp_tv->isFusionInput()) { + continue; + } + auto inp_tv_copy = RecomputeTv::recompute(inp_tv); + ir_utils::replaceValInExprInputs(resize_tensor_op, inp_tv, inp_tv_copy); + } + + for (auto expr : fusion->exprs()) { + if (!expr->isOneOf()) { + continue; + } + + scheduler_tools::propagateResizeToInputs(expr); + } + + auto ref_tv = getReferenceTensor(fusion); + + // Just simple scheduling for now. + // TODO: Do something smarter. Can just use the pointwise scheduler? 
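// Roughly, the transforms below take the reference loop domain through
// (illustrative, ignoring any DID axes reordered to the front):
//   [ i0, i1, ... ]        -- flatten
//   [ i{N} ]               -- split by 128, then by 1 << 14
//   [ ceilDiv(ceilDiv(N, 128), 1 << 14), 1 << 14, 128 ]
// with BIDx bound to the 1 << 14 axis, TIDx to the 128 axis, and the
// outermost factor left serial.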
+ + // Make sure the DID ID located at the outermost position + const auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv); + + // Schedule only the remaining IDs + ref_tv->flatten(outermost_pos); + ref_tv->split(outermost_pos, 128); + ref_tv->split(outermost_pos, 1 << 14); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + // Propagate the reference to the other tensors. Note that the + // update flag is enabled so to workaround the resize propagation + // issue. This may not work if there's a tensor that is reshaped + // from the reference tensor, but that should not be the case as the + // reference is picked by the same routine used for the pointwise + // scheduler. + scheduler_tools::scheduleLoopDomainsLike( + fusion->allTvs(), + ref_tv->getLoopDomain(), + /*update_loop_domain_only=*/true); + + inlineMost(); + + markAliases(fusion); +} + +} // namespace nvfuser diff --git a/csrc/scheduler/resize.h b/csrc/scheduler/resize.h new file mode 100644 index 00000000000..b51ecf1e6dd --- /dev/null +++ b/csrc/scheduler/resize.h @@ -0,0 +1,41 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +class Fusion; +class SchedulerRuntimeInfo; +class HeuristicDataCache; + +class ResizeScheduler : public SchedulerEntry { + public: + bool canScheduleCompileTime(Fusion* fusion) override; + bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache = nullptr) override { + return true; + } + + std::unique_ptr computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache) override; + + void schedule(Fusion* fusion, const HeuristicParams* params) override; + + constexpr static SchedulerType schedulerType() { + return SchedulerType::Resize; + } +}; + +} // namespace nvfuser diff --git a/csrc/scheduler/scheduler_types.cpp b/csrc/scheduler/scheduler_types.cpp index 623d5a22697..cf9b974acf5 100644 --- a/csrc/scheduler/scheduler_types.cpp +++ b/csrc/scheduler/scheduler_types.cpp @@ -31,6 +31,8 @@ std::string toString(SchedulerType scheduler_type) { return "matmul"; case SchedulerType::ExprEval: return "expr_eval"; + case SchedulerType::Resize: + return "resize"; case SchedulerType::None: return "none"; default: diff --git a/csrc/scheduler/scheduler_types.h b/csrc/scheduler/scheduler_types.h index 275a1f372e7..caa389abb9a 100644 --- a/csrc/scheduler/scheduler_types.h +++ b/csrc/scheduler/scheduler_types.h @@ -56,15 +56,17 @@ enum class SchedulerType { InnerOuterPersistent, OuterPersistent, Transpose, - ExprEval + ExprEval, + Resize }; //! Define a schedule table to loop over all the heuristics in priority order. 
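// [Illustrative note, not part of this change] Resize is inserted after
// Reduction and before Transpose/PointWise in the priority table below.
// Since ResizeScheduler::canScheduleCompileTime rejects every fusion unless
// EnableOption::ResizeScheduler is set, the new scheduler is effectively
// opt-in for now. A test would typically enable it along these lines
// (EnableOptionsGuard is assumed to be the existing options RAII helper):
//
//   EnableOptionsGuard options_guard;
//   EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler);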
-constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { SchedulerType::ExprEval, SchedulerType::NoOp, SchedulerType::Matmul, SchedulerType::Reduction, + SchedulerType::Resize, SchedulerType::Transpose, SchedulerType::PointWise, SchedulerType::InnerPersistent, diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp new file mode 100644 index 00000000000..747fdbe43e6 --- /dev/null +++ b/csrc/scheduler/tools/domain_map.cpp @@ -0,0 +1,612 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +namespace nvfuser { +namespace scheduler_tools { + +namespace { +// Check if a root ID of a fusion input tensor that is indirectly +// accessed by ops such as torchGather needs to be mapped with +// a reference tensor. Select has a similar effect as squeeze as the +// indexed domain is removed, so the domain does not need to be mapped +// as long as the tensor is a fusion input. Similarly, in index_select +// and torchGather, if the output domain is a broadcast, it does not +// need to be mapped if not resolved. +bool canIgnoreIndexedInputDomainID( + TensorView* input_tv, + IterDomain* root_id, + const ComputeAtMap& ca_map) { + NVF_ERROR(input_tv->isFusionInput()); + for (auto use : input_tv->uses()) { + if (auto select = dynamic_cast(use)) { + if (root_id != select->getIndexedID()) { + return false; + } + } else if (auto index_select = dynamic_cast(use)) { + // If the root_id is an indexed ID, and the consumer ID may be a + // broadcast. In that case, nothing needs to be mapped if the + // consumer broadcast is not resolved + if (root_id != index_select->getIndexedID() || + !ca_map + .getConcreteMappedID( + index_select->getConsumerOfIndexedID(), + IdMappingMode::PERMISSIVE) + ->isBroadcast()) { + return false; + } + } else if (auto gather = dynamic_cast(use)) { + // TODO: Remove this. Once slice is used for torchGather, this + // should not be necessary. For now, it is necessary to not + // break the existing torchGather tests + if (!gather->exactSizes()) { + continue; + } + // If the root_id is an indexed ID, and the consumer ID may be a + // broadcast. 
In that case, nothing needs to be mapped if the + // consumer broadcast is not resolved + if (root_id != gather->getIndexedID() || + !ca_map + .getConcreteMappedID( + gather->getConsumerOfIndexedID(), IdMappingMode::PERMISSIVE) + ->isBroadcast()) { + return false; + } + } else { + // If the input TV is used by any other ops + return false; + } + } + + return true; +} + +// Grab all exact set mappings from consumer to producer domains of +// indexed accesses, e.g., index_select +std::unordered_multimap< + std::shared_ptr>, + std::shared_ptr>> +getIndexedConsumerToProducerMap(Fusion* fusion, const ComputeAtMap& ca_map) { + std::unordered_multimap< + std::shared_ptr>, + std::shared_ptr>> + indexed_id_map; + + for (auto expr : fusion->exprs()) { + if (auto gather = dynamic_cast(expr)) { + auto p_id = gather->getIndexedID(); + auto c_id = gather->getConsumerOfIndexedID(); + indexed_id_map.emplace( + ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), + ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); + } else if (auto index_select = dynamic_cast(expr)) { + auto p_id = index_select->getIndexedID(); + auto c_id = index_select->getConsumerOfIndexedID(); + indexed_id_map.emplace( + ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), + ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); + } else { + // Note there's no consumer ID for select. This means we can't + // just propagate from consumers to indexed producers. It seems + // it's necessary to schedule producers and consumers separately + // in those cases. + continue; + } + } + + return indexed_id_map; +} + +} // namespace + +DomainMap::DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { + tvs_with_rfactor_ = scheduler_utils::getTVsWithNonReductionRFactor(fusion); +} + +// Determine if all IterDomains in input are mapped to the given tensor +bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) + const { + // Get concrete IDs for input root or logical domain + std::unordered_set in_concrete_ids; + for (auto in_id : input_tv->getLogicalDomain()) { + if (canIgnoreIndexedInputDomainID(input_tv, in_id, ca_map_)) { + continue; + } + + // Permissive map is required for the transpose scheduler to support cases + // like T0[I0, b] + T1[b, I1] + auto concrete = + ca_map_.getConcreteMappedID(in_id, IdMappingMode::PERMISSIVE); + + if (!concrete->isBroadcast() && !in_id->isReduction()) { + in_concrete_ids.insert(concrete); + } + } + + // Erase all input concrete IDs mapped to the output domain + // Ignore unresolved broadcast dimensions + eraseifInputMappedThroughRootDomainAndIndexing( + in_concrete_ids, tv->getLogicalDomain()); + + return in_concrete_ids.empty(); +} + +// Note: ideally we would want to check that reference_tv contains all iter +// domains in target_tv, so that transformation applied on reference_tv can be +// propagated to target_tv. But we don't have an easy way to check that. Instead +// of that, this function checks that all source iter domains involved in +// transformation on target_tv is covered by reference_tv. Source iter domains +// of TensorViews are IDs that doesn't have an definition and are producers of +// any IDs on the logical domain of the given TensorView. +// +// ------ +// +// e.g 0. +// T34 [i0, i1] +// T185 [i0, b2, i1] = broadcast(T34) +// T192 [i0, b3(ex), i1] = expand(T185) +// T198 [i0, b3(ex)*i1] = reshape(T192) +// output(T34) +// output(T198) +// +// if we consider taking T34 as reference_tv. T198 is the target_tv. 
We can't +// replay T34's transform of merging all the dimensions to T198, since b3(ex)*i1 +// can't be reversed. The check in this function would give us T34 with source +// i0, i1; where T198 would have source i0, b3, i1, where b3 isn't contained in +// T34. Hence we'll reject this reference_tv. +// +// ------ +// +// e.g 1. +// T0 [i0, i1] +// T1 [i2, i0, i1] +// T2 [i0*i1] = reshape(T0) +// T3 [b3, i0, i1] = broadcast(T0) +// T4 [i2, i0, i1] = add(T1, T3) +// output(T2) +// output(T4) +// +// the example above should be able to pick T4 as reference_tv. T2's source i0, +// i1 are both contained by the source of T4, so this example could be scheduled +// as a single fusion. +bool DomainMap::areAllTargetIdsCoveredBy( + TensorView* target_tv, + TensorView* reference_tv) const { + auto get_source_iter_domains = [this](const std::vector& ids) { + // traverse back to collect all disjoint set producer IDs for each ID in the + // logical domain of tv. + VectorOfUniqueEntries>> + all_producer_sets; + std::for_each(ids.begin(), ids.end(), [&](IterDomain* tv_logical_id) { + all_producer_sets.pushBack( + ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT)); + }); + all_producer_sets.pushBack( + ca_map_.getAllDisjointSetProducers(all_producer_sets)); + + std::vector source_ids; + // filtering all producer IDs with empty definition to get source iter + // domains + std::for_each( + all_producer_sets.vector().begin(), + all_producer_sets.vector().end(), + [&source_ids, + this](const std::shared_ptr>& + producer_set_ptr) { + IterDomain* producer_id = producer_set_ptr->front(); + if (ca_map_.uniqueExactDefinitions(producer_id).empty()) { + source_ids.push_back(producer_id); + } + }); + return source_ids; + }; + + // this contains all source iter domain that's covered by reference_tv, so + // it's safe for target_tv to have them. + std::unordered_set covered_source_ids; + for (IterDomain* source_id_ref : + get_source_iter_domains(reference_tv->getLogicalDomain())) { + covered_source_ids.insert(source_id_ref); + } + // It's safe to have unmapped broadcast IterDomain. There're quite a few tests + // expecting pointwise scheduler to handle this pattern + for (IterDomain* id_out : target_tv->getLogicalDomain()) { + if (id_out->isBroadcast()) { + NVF_ERROR( + id_out->definition() == nullptr || + id_out->definition()->isA()); + + // Note that ideally we should also be able to handle merge/split on + // broadcast IDs, so we should really move this skip inside the loop below + // `get_source_iter_domains(target_tv->getLogicalDomain())` and skip + // broadcast source IDs. currently we have the issue that split/merge does + // not preserve expanded broadcasts, see issue: + // https://github.com/NVIDIA/Fuser/issues/1126 + covered_source_ids.insert(id_out); + } + } + // Note: there's certain cases where it's safe to have dangling IDs, + // e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T192 [i0, b3(ex), i1] = expand(T185) + // It's safe to propagate T34 to T192, since b3(ex) is not involved in the + // propagation. But this isn't generally safe. 
If the above example is changed + // to e.g + // T34 [i0, i1] + // T185 [i0, b2, i1] = broadcast(T34) + // T186 [i0, i4, i1] = ones({i0, i4, i1}) + // T193 [i0, i4, i1] = add(T185, T186) + // It's unsafe to propagate from T34 to T193, see issue + // https://github.com/NVIDIA/Fuser/issues/3542 + + // Check all source iter domain involved in producing target_tv + for (IterDomain* source_id_out : + get_source_iter_domains(target_tv->getLogicalDomain())) { + // NOTE: we use concrete id instead. This allows us to link indirect + // broadcast. So in the example below: + // input T0[ + // T2[i0, i2*i3] = T0[i0, i2, i3] + // T3[i0, i2*i3] = T1[i0, b0] + T2[i0, i2*i3] + // T4[i0, i9] = pad(T1[i0, b0]) + // We have i9 in T3 + // -> source ID b0 + // -> concrete map to i2*i3 + // -> source ID from i2*i3 to [i2, i3] + // So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1` + auto concrete_id_out = + ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE); + + // After mapping with PERMISSIVE map, `concrete_id_out` might no longer be a + // source ID. We project to source ID again from concrete_id_out. See test + // DomainMapBroadcastIssue3653 + // In the example above. `i2*i3` is not a source ID. Hence we needed to go + // through another projection to source IDs in order to map it to + // covered_source_ids. + for (IterDomain* concrete_source_id_out : + get_source_iter_domains({concrete_id_out})) { + // if we find any source_id_out that's not contained, it's possible our + // propagation would fail since transformation involving this iter + // domain can't be resolved. + if (!getMappedInputConcreteID( + covered_source_ids, concrete_source_id_out)) { + return false; + } + } + } + return true; +} + +// Reference domains must exactly match with the input domains. See +// also PR #661 +IterDomain* DomainMap::getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const { + auto in_concrete_id_iter = std::find_if( + in_concrete_ids.begin(), + in_concrete_ids.end(), + [&](IterDomain* in_concrete_id) { + return ca_map_.areMapped(in_concrete_id, out_id, IdMappingMode::EXACT); + }); + if (in_concrete_id_iter != in_concrete_ids.end()) { + return *in_concrete_id_iter; + } else { + return nullptr; + } +} + +// Erase input concrete ID if it is mapped to output ID +bool DomainMap::eraseIfMapped( + std::unordered_set& in_concrete_ids, + IterDomain* out_id) const { + auto mapped_input_conrete_id = + getMappedInputConcreteID(in_concrete_ids, out_id); + if (mapped_input_conrete_id != nullptr) { + in_concrete_ids.erase(mapped_input_conrete_id); + return true; + } else { + return false; + } +} + +void DomainMap::eraseifInputMappedThroughRootDomainAndIndexing( + std::unordered_set& in_ids, + const std::vector& ids) const { + // Use ComputeAtMap::getAllDisjointSetProducers to grab all producer + // IDs through rfactor exprs + VectorOfUniqueEntries>> + exact_sets; + std::for_each(ids.begin(), ids.end(), [&](IterDomain* id) { + exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + + // Traverse through indexed domains. 
+ const auto indexed_id_multimap = + getIndexedConsumerToProducerMap(fusion_, ca_map_); + + VectorOfUniqueEntries>> + all_exact_sets_covered; + + // Back traverses through the exact map and indexed + // producer-consumer pairs + for (auto current_sets = exact_sets; !current_sets.empty();) { + auto producer_sets = ca_map_.getAllDisjointSetProducers(current_sets); + all_exact_sets_covered.pushBack(producer_sets); + + current_sets.clear(); + + // Further traversal if any of the new producer sets is a producer + // of indexed domains + for (const auto& producer_set : producer_sets) { + auto indexed_id_multimap_range = + indexed_id_multimap.equal_range(producer_set); + for (auto producer_of_producer_it = indexed_id_multimap_range.first; + producer_of_producer_it != indexed_id_multimap_range.second; + ++producer_of_producer_it) { + current_sets.pushBack(producer_of_producer_it->second); + } + } + } + + for (const auto& exact_set_ptr : all_exact_sets_covered) { + auto exact_concrete_id = ca_map_.getConcreteMappedID( + exact_set_ptr->front(), IdMappingMode::EXACT); + eraseIfMapped(in_ids, exact_concrete_id); + } +} + +// Find any id in domain that maps with target id +IterDomain* DomainMap::anyMapped( + const std::vector& domain, + IterDomain* target) const { + for (auto id : domain) { + if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) { + return id; + } + } + return nullptr; +} + +// Determine if output TensorView is a valid reference tensor for this fusion. +// The reference tensor must map to all the iterDomains in each input and +// output +bool DomainMap::isValidReference(TensorView* tv) const { + for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { + if (input_tv->uses().empty()) { + continue; + } + // TODO: Same backward traversal from tv is done for all input + // tvs. Consider doing the analysis one for all inputs + if (!areAllInputIdsMappedTo(input_tv, tv)) { + return false; + } + } + // The check on outputs are optional, transpose scheduler might propose a + // secondary reference that only applies to a subset of IO tensors. Ideally + // we should have a more robust check and consider the IO groups instead of + // blindly skip outputs. + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + // no need to check for self. + if (output_tv == tv) { + continue; + } + if (!areAllTargetIdsCoveredBy(output_tv, tv)) { + return false; + } + } + return true; +} + +TensorView* PointwiseDomainMap::findReferenceTensor( + int64_t minimum_num_axes) const { + TensorView* result = nullptr; + int64_t max_dims = -1; + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + if (isValidReference(output_tv) && + hasMinimumSize(output_tv, minimum_num_axes) && + !output_tv->isFusionInput()) { + int64_t n_dims = scheduler_utils::nLogicalDims(output_tv); + if (n_dims > max_dims) { + result = output_tv; + max_dims = n_dims; + } + } + } + return result; +} + +TensorView* TransposeDomainMap::findReferenceFor( + const std::vector& group) const { + TensorView* result = nullptr; + int64_t max_dims = -1; + for (auto tv : group) { + // since transpose scheduler have different set of reference, we skip IDs + // coverage check of the reference on outputs of the fusion. Note that + // this is not ideal, we would want to instead have reference tensor + // checked against all its target IO tensors. + // TODO: open an issue for this one. transpose scheduler is not supposed + // to reuse pointwise_utils::DomainMap::isValidRefrence. 
This function is + // too restrictive and doesn't align well with the scheme of transpose + // scheduler + if (isValidReference(tv)) { + int64_t dims = scheduler_utils::nLogicalDims(tv); + if (dims > max_dims) { + result = tv; + max_dims = dims; + } + } + } + return result; +} + +IterDomain* TransposeDomainMap::getMappedAllocDimIn( + TensorView* tv, + IterDomain* root_dim) const { + // Find the id mapped to `Allocation Domain` + const auto& alloc_dom = tv->getMaybeAllocationDomain(); + IterDomain* mapped_id = nullptr; + for (auto i : c10::irange(alloc_dom.size())) { + if (ca_map_.areMapped(alloc_dom[i], root_dim, IdMappingMode::INNERMOST)) { + mapped_id = alloc_dom[i]; + break; + } + } + return mapped_id; +} + +bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { + FusionGuard fg(fusion); + TransposeDomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + if (grouped_inputs_outputs.size() < 2) { + return false; + } + auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); + auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]); + if (ref1 == nullptr || ref2 == nullptr) { + return false; + } + // reference 1 is the global reference, so it must have dim mapped the + // innermost dim of both groups + auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); + return domain_map.getMappedAllocDimIn(ref1, innermost2) != nullptr; +} + +int64_t TransposeDomainMap::getInnerLeafDim( + TensorView* tv, + IterDomain* root_dim) const { + // TODO: ideally we should be mapping to loop domain directly here. + // However, our current compute at map is constructed before loop domain is + // transformed. So the mapping here would require a new compute at map to be + // constructed from the updated fusion. We'll revisit this once our id graph + // refactor is done. + auto mapped_id = getMappedAllocDimIn(tv, root_dim); + NVF_ERROR( + mapped_id != nullptr, + "Can not find ID mapped to ", + root_dim, + " in tensor ", + tv); + std::vector replay_exprs = StmtSort::getExprsBetween( + {mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + // Project the root id to loop id. Similar to projectIdToRFactor. + for (auto* expr : replay_exprs) { + if (auto* split = dynamic_cast(expr)) { + if (split->in() == mapped_id) { + if (split->inner()->extent()->isOneInt() && + !split->outer()->extent()->isOneInt()) { + mapped_id = split->outer(); + } else { + mapped_id = split->inner(); + } + } + } else if (auto* merge = dynamic_cast(expr)) { + // Merge with size-1 dimension is not supposed to be here, reshape would + // map this to a squeeze. This is a conservative assert, we can relaxed + // it and support with mapping it to out. 
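      // [Illustrative note, not part of this change] Example of the
      // projection performed by this loop: for a tv with allocation domain
      // [i0, i1] scheduled as
      //
      //   tv->merge(0);    // loop: [i0*i1]
      //   tv->split(0, 4); // loop: [(i0*i1)/4, 4]
      //
      // getInnerLeafDim(tv, i1) follows the Merge (i1 is the inner input, so
      // mapped_id becomes i0*i1) and then the Split (the inner output has
      // extent 4, so it is taken), returning loop position 1.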
+ NVF_ERROR( + !merge->inner()->extent()->isOneInt(), + "merge with size-1 dimension is supposed to be translated to squeeze by reshape"); + if (merge->inner() == mapped_id) { + mapped_id = merge->out(); + } + } else if (auto* resize = dynamic_cast(expr)) { + if (resize->in() == mapped_id) { + mapped_id = resize->out(); + } + } + } + + // Find the position of the loop id + const auto& dom = tv->getLoopDomain(); + for (auto i : c10::irange(dom.size())) { + if (dom[i] == mapped_id) { + return static_cast(i); + } + } + return -1; +} + +std::vector> TransposeDomainMap:: + groupInputsOutputsByInnerDim() const { + std::vector> groups; + auto output_tvs = ir_utils::filterByType(fusion_->outputs()); + auto input_tvs = ir_utils::filterByType(fusion_->inputs()); + std::unordered_set grouped; + std::array tv_filtered_groups = { + &output_tvs, &input_tvs}; + for (auto tv_filtered_group : tv_filtered_groups) { + for (auto tv : *tv_filtered_group) { + if (tv->isFusionInput() && tv->uses().empty()) { + continue; + } + if (grouped.count(tv) > 0) { + continue; + } + groups.emplace_back(std::vector{tv}); + grouped.emplace(tv); + // We only want to grab the inner-most dimension, because we don't want + // tensors with different inner-most dimension to be put in the same + // group. For example, if we have: + // T2[i1, i3*i2] = relu(view(transpose(T1[i1, i2, i3]))) + // then we don't want T1 and T2 to be in the same group. + // + // But we don't want to check contiguity. For example, if we have: + // T1[i1, i2, i3] (contiguous) + T2[i1, i2, i3] (discontiguous) + // Then we still want to T1 and T2 to be grouped together. + auto group = + scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false); + if (group.empty()) { + // In case that the inner most dim of tv is not found (for example, tv + // is a fusion input with only reductions), we just return a null + // result which will tell the scheduler to reject the fusion + return {}; + } + for (auto member_tv : group) { + if (grouped.count(member_tv) == 0) { + grouped.emplace(member_tv); + groups.back().emplace_back(member_tv); + } else if (member_tv != tv) { + // Ambiguous grouping. This should only happen at `canSchedule`, so + // we just return a null result which will tell the scheduler to + // reject the fusion + return {}; + } + } + } + } + std::stable_sort( + groups.begin(), + groups.end(), + [](const std::vector& v1, + const std::vector& v2) { return v1.size() > v2.size(); }); + return groups; +} + +IterDomain* TransposeDomainMap::getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const { + auto in_concrete_id_iter = std::find_if( + in_concrete_ids.begin(), + in_concrete_ids.end(), + [&](IterDomain* in_concrete_id) { + return ca_map_.areMapped( + in_concrete_id, out_id, IdMappingMode::PERMISSIVE); + }); + if (in_concrete_id_iter != in_concrete_ids.end()) { + return *in_concrete_id_iter; + } else { + return nullptr; + } +} + +} // namespace scheduler_tools +} // namespace nvfuser diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h new file mode 100644 index 00000000000..8a8ccb33e91 --- /dev/null +++ b/csrc/scheduler/tools/domain_map.h @@ -0,0 +1,156 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include +#include + +namespace nvfuser { + +class Fusion; +class TensorView; +class IterDomain; + +namespace scheduler_tools { + +// DomainMap uses the ComputeAtMap to find a reference TensorView +// that maps to all IterDomains in the fusion. +class DomainMap { + public: + DomainMap(Fusion* fusion); + virtual ~DomainMap() = default; + + const ComputeAtMap& getComputeAtMap() const { + return ca_map_; + } + + // Determine if a TensorView is a valid reference tensor for this fusion. + // The reference tensor must map to all the iterDomains in each input and + // output. + bool isValidReference(TensorView* tv) const; + + protected: + // Determine if all IterDomains are mapped between input and the given tvs + bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) + const; + + // Determine if all source IterDomains in target_tv are contained by the + // reference_tv, this ensures transformations from reference_tv can be + // propagated to target_tv + bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv) + const; + + virtual IterDomain* getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const; + + // Erase input concrete ID if it is mapped to output ID + bool eraseIfMapped( + std::unordered_set& in_concrete_ids, + IterDomain* out_id) const; + + // Check if in_ids are mapped to ids through any root domain as + // well as indirectly accessed domains with ops like torchGather + void eraseifInputMappedThroughRootDomainAndIndexing( + std::unordered_set& in_ids, + const std::vector& ids) const; + + // Find any id in domain that maps with target id + IterDomain* anyMapped( + const std::vector& domain, + IterDomain* target) const; + + Fusion* fusion_ = nullptr; + ComputeAtMap ca_map_; + std::vector tvs_with_rfactor_; +}; + +class PointwiseDomainMap : public scheduler_tools::DomainMap { + public: + using scheduler_tools::DomainMap::DomainMap; + + // The pointwise scheduler heuristics requires a minimum number of axes. + // The output reference tensor should respect this requirement. + TensorView* findReferenceTensor(int64_t minimum_num_axes = 0) const; + + private: + bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { + NVF_ERROR(tv != nullptr); + return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); + } +}; + +// DomainMap uses the ComputeAtMap to find a reference TensorView +// that maps to all iterDomains in the fusion. +class TransposeDomainMap : public scheduler_tools::DomainMap { + public: + using scheduler_tools::DomainMap::DomainMap; + + // Note that this may not be able to find any reference if any + // tensor in the group is only connected with an input through + // rfactor or gather-like indexing ops. It is because + // isValidReference is based a backward traversal, so there may not + // be a traversal path to an input. This type of analysis is + // expected to be possible much more easily with the new indexing + // graph (#32), so we should revisit once it becomes available. 
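  // [Illustrative usage sketch, not part of this change] This mirrors how
  // the transpose scheduler drives the class (see transpose.cpp later in
  // this diff):
  //
  //   scheduler_tools::TransposeDomainMap domain_map(fusion);
  //   auto groups = domain_map.groupInputsOutputsByInnerDim();
  //   // groups are sorted by size; pick one reference per group
  //   TensorView* ref1 = domain_map.findReferenceFor(groups[0]);
  //   TensorView* ref2 = domain_map.findReferenceFor(groups[1]);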
+ TensorView* findReferenceFor(const std::vector& group) const; + + IterDomain* getMappedAllocDimIn(TensorView* tv, IterDomain* root_dim) const; + + static bool hasAtLeastTwoValidGroups(Fusion* fusion); + + // scheduler assumes inner loop dimension on tv is an exact mapping, when the + // mapping cannot be resolved, we'll return a `-1` + int64_t getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const; + + // Group inputs and outputs of a fusion by its inner most domain. For example + // inputs: t0, t1 + // t2 = transpose(t1) + // t3 = t0 + t2 + // t4 = sin(t0) + // t5 = cos(t1) + // outputs: t3, t4, t5 + // + // Then we should have group {t0, t3, t4} and {t1, t5} + // + // The returned groups are sorted in descending size. If the sizes of two + // group are equal, then we sort them by their members in the following order: + // output[0], output[1], ..., input[0], input[1], ... + // That is, {ouput[0], output[2]} will be in front of {ouput[1], output[3]} + // The order here must be deterministic, because in transpose heuristics, we + // have `vectorize_factor1` and `vectorize_factor2` and we need to be sure + // that `1` and `2` are assigned to the same group across runs. + // + // In the case where view is present in the graph, there are two cases: if the + // view doesn't touch any inner dimension of any group, then the support of it + // is trivial. In the case where view actually touches an inner-most dim, we + // keep track of the inner-most dimension of view's split and merges. + // + // For example, if you have: + // T0 [2, 3, 5] <-- input + // T1 [2, 5, 3] <-- input + // T2 [2, 5, 3] = transpose(T0) + T1 + // T3 [2, 15] = view(T2) + // output <-- T3 + // + // Then T3 should be in the same group with T1, and T0 should have + // different group with T1 and T3. + std::vector> groupInputsOutputsByInnerDim() const; + + // In the transpose scheculing, unlike the pointwise scheduling, the + // permissive map is required to find reference tensors. See also PR + // #661 + IterDomain* getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const override; +}; + +} // namespace scheduler_tools +} // namespace nvfuser diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index f04a2f2271e..327d432ad2c 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -238,8 +238,9 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { } const auto path_from_ref = getReplayPath(tv); - const ExprGroups all_existing_expr_groups = - graph().toGroups(tv->domain()->allExprs()); + const ExprGroups all_existing_expr_groups = update_loop_domain_only_ + ? ExprGroups{} + : graph().toGroups(tv->domain()->allExprs()); // Replay the path on the target tensor for (const auto& [expr_g, dir] : path_from_ref) { diff --git a/csrc/scheduler/tools/resize_utils.cpp b/csrc/scheduler/tools/resize_utils.cpp index cc914e5684b..ddecf6bcb13 100644 --- a/csrc/scheduler/tools/resize_utils.cpp +++ b/csrc/scheduler/tools/resize_utils.cpp @@ -50,7 +50,9 @@ void propagateResizeToInputs(Expr* resize_tensor_op) { // Before doing so, all the dependent tensors need to have the exact-mapped // loop domain. 
scheduler_tools::scheduleLoopDomainsLike( - tvs_to_schedule, producer_tv->getLoopDomain()); + tvs_to_schedule, + producer_tv->getLoopDomain(), + /*update_loop_domain_only=*/true); // Now that all the dependent tensors have the uniform, exact-mapped // loop domains, we just need to propagte the specific Resize ops of @@ -66,5 +68,118 @@ void propagateResizeToInputs(Expr* resize_tensor_op) { } } +std::unordered_map getNonExclusiveResizeInfo( + const std::vector& ordered_resize_tensor_ops, + const ValGraph& exact_graph) { + NVF_ERROR(!ordered_resize_tensor_ops.empty()); + Fusion* fusion = ordered_resize_tensor_ops[0]->fusion(); + + std::unordered_map non_exclusive_resizes; + + std::unordered_set inputs{ + fusion->inputs().begin(), fusion->inputs().end()}; + + auto get_root_to_logical_resizes = + [&exact_graph](TensorView* tv) -> ValGroups { + // This should be only used for outputs of resize-based ops, + // so it should always have a root domain. + NVF_ERROR(tv->hasRoot()); + auto out_tv_root_to_logical_exprs = DependencyCheck::getAllExprsBetween( + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); + ValGroups resize_inp_ids; + for (auto resize : + ir_utils::filterByType(out_tv_root_to_logical_exprs)) { + resize_inp_ids.pushBack(exact_graph.toGroup(resize->in())); + } + return resize_inp_ids; + }; + + // Traverse the ops in a topological order + for (Expr* resize_tensor_op : ordered_resize_tensor_ops) { + auto inp_tv = dynamic_cast(resize_tensor_op->inputs().at(0)); + auto out_tv = dynamic_cast(resize_tensor_op->outputs().at(0)); + + ResizeExclusivityInfo info; + + ValGroups resize_inp_ids = get_root_to_logical_resizes(out_tv); + NVF_ERROR(!resize_inp_ids.empty()); + + auto dep_vals = + DependencyCheck::getAllValsBetween(inputs, std::vector{inp_tv}); + + // For each tensor that inp_tv depends on, check if the resize op + // is considered non-exclusive with respect to the tensor. That + // is, if propagation of the resize may result in externally + // visible changes through the tensor, the resize is considered + // non-exclusive. + for (auto dep_tv : ir_utils::filterByType(dep_vals)) { + bool maybe_non_exclusive = dep_tv->isFusionOutput(); + + if (!maybe_non_exclusive) { + // If a dependent tv has a consumer that inp_tv does not + // depend on, propagation of resize would escape to outputs, + // which needs to be avoided. + for (auto consumer_tv : ir_utils::consumerTvsOf(dep_tv)) { + // We are interested in if resized IDs are used by other tensors + // than out_tv + if (consumer_tv != out_tv && + std::find(dep_vals.begin(), dep_vals.end(), consumer_tv) == + dep_vals.end()) { + maybe_non_exclusive = true; + break; + } + } + } + + if (!maybe_non_exclusive) { + continue; + } + + // dep_tv potentially is either a fusion output or it has a + // consumer outside of the dependency set to the resized + // tensor. Propagating the resize to dep_tv should be + // avoided. However, if the dep_tv iter domain that corresponds + // to the resized ID is a broadcast or there's no such ID, it + // should still be safe to consider the resize op exclusive as + // there's no iter domain to resize. For a concrete example, see + // ResizeSchedulerTest.PropagateMultipleSlicesToInputs4. 
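      // [Illustrative sketch, not part of this change] e.g. (using the same
      // pseudo notation as the comments in resize_utils.h):
      //
      //   t0 = [i0]          // fusion input
      //   t1 = [b1]          // fusion input whose only ID is a broadcast
      //   t2 = t0 + t1       // [i0], input to the slice
      //   t3 = slice(t2)     // resizes i0
      //   t4 = t1 + 1
      //   fusion.addOutput(t3)
      //   fusion.addOutput(t4)
      //
      // t1 is a dependency of the slice input t2 and also feeds t4, but in
      // the exact graph t1 has no ID corresponding to the resized i0 (its
      // only ID is b1), so the slice can still be treated as exclusive with
      // respect to t1. The check below implements that by looking for the
      // resize input ID group among the val groups between the logical
      // domains of inp_tv and dep_tv.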
+ const auto inp_tv_logical_groups = + exact_graph.toGroups(inp_tv->getLogicalDomain()); + const auto dep_tv_logical_groups = + exact_graph.toGroups(dep_tv->getLogicalDomain()); + auto vals_between = getValsBetween( + {inp_tv_logical_groups.begin(), inp_tv_logical_groups.end()}, + {dep_tv_logical_groups.begin(), dep_tv_logical_groups.end()}, + exact_graph); + + for (const ValGroup& resize_inp_id : resize_inp_ids) { + if (std::find( + vals_between.begin(), vals_between.end(), resize_inp_id) == + vals_between.end()) { + // This resize can be ignored as there's no corresponding ID + // in the dep tv + continue; + } + + // This resize input ID is not exclusively used + info.non_exclusive_dep_tvs.push_back(dep_tv); + info.resized_ids.pushBack(resize_inp_id); + } + } + + if (!info.non_exclusive_dep_tvs.empty()) { + NVF_ERROR(non_exclusive_resizes.emplace(out_tv, info).second); + } + + // Analysis of exclusiveness until in_tv is done. Following + // resize-based tensor ops do not need to check the same section + // of the fusion and can start from out_tv. + inputs.insert(out_tv); + } + + return non_exclusive_resizes; +} + } // namespace scheduler_tools } // namespace nvfuser diff --git a/csrc/scheduler/tools/resize_utils.h b/csrc/scheduler/tools/resize_utils.h index cf03083ad4f..b9afed5effa 100644 --- a/csrc/scheduler/tools/resize_utils.h +++ b/csrc/scheduler/tools/resize_utils.h @@ -7,9 +7,12 @@ // clang-format on #pragma once +#include + namespace nvfuser { class Expr; +class TensorView; namespace scheduler_tools { @@ -19,5 +22,97 @@ namespace scheduler_tools { // fusion inputs are skipped as their loop domains don't matter. void propagateResizeToInputs(Expr* resize_op); +// Given a topologically ordered list of resize-based tensor ops such +// as slice and pad, check if they can be propagated to fusion inputs +// exclusively without causing any visible side effect. For example, +// if a tensor is sliced and also is used to produce an output without +// the slicing, the slice is considered non exclusive as the slice +// input has the other visible consumer. Propagating the resize of the +// slice to the slice input is invalid since the output computed from +// the slice input depends on the full iteration space. +// +// For example, consider the following case: +// +// t0 = makeSymbolicTensor(1) +// fusion.addInput(t0) +// t1 = t0 + 1 +// t2 = t1[1:10] +// t3 = t1 + 1 +// fusion.addOutput(t2) +// fusion.addOutput(t3) +// +// In this case, propating the resize op of the slice would alter t1, +// which would in turn affect t3, which is a fusion output. Since the +// change would be visible due to the change of t3, this resize op is +// considered non-exclusive. +// +// Consider a slightly different case as shown below: +// +// t0 = makeSymbolicTensor(1) +// fusion.addInput(t0) +// t1 = t0[1:10] +// t2 = t0 + 1 +// fusion.addOutput(t1) +// fusion.addOutput(t2) +// +// Note that the slice is directly done with the fusion input. Since +// we do not propagate resize ops to fusion inputs, this can be +// considered exclusive. However, this is also considered +// non-exclusive since the actual scheduling inserts a cache after t0, +// which can cause a visible side effect if the resize is propagated. +// +// Another non-exclusivness comes from dependent fusion outputs. For +// example, if a slice input depends on a fusion output, propagation +// would alter the fusion output. 
Consider a case like: +// +// t0 = makeSymbolicTensor(1) +// fusion.addInput(t0) +// t1 = t0 + 1 +// t2 = t1[1:10] // slice +// fusion.addOutput(t1) +// fusion.addOutput(t2) +// +// If the resize op for the slice is propagated to t1, only the +// section of [1:10] would be computed. Since that would change a +// fusion output, the resize op is considered non-exclusive. +// +// When there's a chain of resize-based ops, for example: +// +// t0 = makeSymbolicTensor(1) +// fusion.addInput(t0) +// t1 = t0 + 1 +// t2 = t1[1:10] +// t3 = t2[2:5] +// t4 = t1 + 1 +// fusion.addOutput(t3) +// fusion.addOutput(t4) +// +// We do not consider the second slice as non-exclusive as +// long as the first slice is considered non-exclusive. This will be +// important when resolving the non-exclusiveness by replication. +// +// The function returns a map from tensors that are outputs to +// non-exclusive ops to ResizeExclusivityInfo. This map will be +// used to resolve the non-exclusiveness by replication. +struct ResizeExclusivityInfo { + // Dependent tensors that should not be resized + std::vector non_exclusive_dep_tvs; + // ID groups of resize input IDs + ValGroups resized_ids; + + bool operator==(const ResizeExclusivityInfo& other) const { + return non_exclusive_dep_tvs == other.non_exclusive_dep_tvs && + resized_ids == other.resized_ids; + } + + bool operator!=(const ResizeExclusivityInfo& other) const { + return !(*this == other); + } +}; + +std::unordered_map getNonExclusiveResizeInfo( + const std::vector& ordered_resize_tensor_ops, + const ValGraph& exact_graph); + } // namespace scheduler_tools } // namespace nvfuser diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 7e320f99a91..bef6b2b7def 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -68,12 +68,6 @@ bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) { } } - if (!hasAtLeastTwoValidGroups(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), "cannot find two mismatching inner most dimensions"); - return false; - } - if (ir_utils::hasAnyReductionOps(fusion)) { scheduler_debug_utils::canScheduleRejectReason( schedulerType(), "no support for reduction ops"); @@ -87,6 +81,12 @@ bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } + if (!hasAtLeastTwoValidGroups(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "cannot find two mismatching inner most dimensions"); + return false; + } + return true; } @@ -153,230 +153,6 @@ bool hasSmallTransposeDimensions( !params->dims_merged_with_2.empty(); } -// DomainMap uses the ComputeAtMap to find a reference TensorView -// that maps to all iterDomains in the fusion. -class DomainMap : public pointwise_utils::DomainMap { - public: - using pointwise_utils::DomainMap::DomainMap; - - // Note that this may not be able to find any reference if any - // tensor in the group is only connected with an input through - // rfactor or gather-like indexing ops. It is because - // isValidReference is based a backward traversal, so there may not - // be a traversal path to an input. This type of analysis is - // expected to be possible much more easily with the new indexing - // graph (#32), so we should revisit once it becomes available. 
- TensorView* findReferenceFor(const std::vector& group) const { - TensorView* result = nullptr; - int64_t max_dims = -1; - for (auto tv : group) { - if (isValidReference(tv)) { - int64_t dims = (int64_t)pointwise_utils::nRootDims(tv); - if (dims > max_dims) { - result = tv; - max_dims = dims; - } - } - } - return result; - } - - IterDomain* getMappedAllocDimIn(TensorView* tv, IterDomain* root_dim) const { - // Find the id mapped to `Allocation Domain` - const auto& alloc_dom = tv->getMaybeAllocationDomain(); - IterDomain* mapped_id = nullptr; - for (auto i : c10::irange(alloc_dom.size())) { - if (ca_map_.areMapped(alloc_dom[i], root_dim, IdMappingMode::INNERMOST)) { - mapped_id = alloc_dom[i]; - break; - } - } - return mapped_id; - } - - static bool hasAtLeastTwoValidGroups(Fusion* fusion) { - FusionGuard fg(fusion); - DomainMap domain_map(fusion); - auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); - if (grouped_inputs_outputs.size() < 2) { - return false; - } - auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); - auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]); - if (ref1 == nullptr || ref2 == nullptr) { - return false; - } - // reference 1 is the global reference, so it must have dim mapped the - // innermost dim of both groups - auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); - return domain_map.getMappedAllocDimIn(ref1, innermost2) != nullptr; - } - - // scheduler assumes inner loop dimension on tv is an exact mapping, when the - // mapping cannot be resolved, we'll return a `-1` - int64_t getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const { - // TODO: ideally we should be mapping to loop domain directly here. - // However, our current compute at map is constructed before loop domain is - // transformed. So the mapping here would require a new compute at map to be - // constructed from the updated fusion. We'll revisit this once our id graph - // refactor is done. - auto mapped_id = getMappedAllocDimIn(tv, root_dim); - NVF_ERROR( - mapped_id != nullptr, - "Can not find ID mapped to ", - root_dim, - " in tensor ", - tv); - std::vector replay_exprs = StmtSort::getExprsBetween( - {mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); - // Project the root id to loop id. Similar to projectIdToRFactor. - for (auto* expr : replay_exprs) { - if (auto* split = dynamic_cast(expr)) { - if (split->in() == mapped_id) { - if (split->inner()->extent()->isOneInt() && - !split->outer()->extent()->isOneInt()) { - mapped_id = split->outer(); - } else { - mapped_id = split->inner(); - } - } - } else if (auto* merge = dynamic_cast(expr)) { - // Merge with size-1 dimension is not supposed to be here, reshape would - // map this to a squeeze. This is a conservative assert, we can relaxed - // it and support with mapping it to out. - NVF_ERROR( - !merge->inner()->extent()->isOneInt(), - "merge with size-1 dimension is supposed to be translated to squeeze by reshape"); - if (merge->inner() == mapped_id) { - mapped_id = merge->out(); - } - } else if (auto* resize = dynamic_cast(expr)) { - if (resize->in() == mapped_id) { - mapped_id = resize->out(); - } - } - } - - // Find the position of the loop id - const auto& dom = tv->getLoopDomain(); - for (auto i : c10::irange(dom.size())) { - if (dom[i] == mapped_id) { - return static_cast(i); - } - } - return -1; - } - - // Group inputs and outputs of a fusion by its inner most domain. 
For example - // inputs: t0, t1 - // t2 = transpose(t1) - // t3 = t0 + t2 - // t4 = sin(t0) - // t5 = cos(t1) - // outputs: t3, t4, t5 - // - // Then we should have group {t0, t3, t4} and {t1, t5} - // - // The returned groups are sorted in descending size. If the sizes of two - // group are equal, then we sort them by their members in the following order: - // output[0], output[1], ..., input[0], input[1], ... - // That is, {ouput[0], output[2]} will be in front of {ouput[1], output[3]} - // The order here must be deterministic, because in transpose heuristics, we - // have `vectorize_factor1` and `vectorize_factor2` and we need to be sure - // that `1` and `2` are assigned to the same group across runs. - // - // In the case where view is present in the graph, there are two cases: if the - // view doesn't touch any inner dimension of any group, then the support of it - // is trivial. In the case where view actually touches an inner-most dim, we - // keep track of the inner-most dimension of view's split and merges. - // - // For example, if you have: - // T0 [2, 3, 5] <-- input - // T1 [2, 5, 3] <-- input - // T2 [2, 5, 3] = transpose(T0) + T1 - // T3 [2, 15] = view(T2) - // output <-- T3 - // - // Then T3 should be in the same group with T1, and T0 should have - // different group with T1 and T3. - std::vector> groupInputsOutputsByInnerDim() const { - std::vector> groups; - auto output_tvs = ir_utils::filterByType(fusion_->outputs()); - auto input_tvs = ir_utils::filterByType(fusion_->inputs()); - std::unordered_set grouped; - std::array tv_filtered_groups = { - &output_tvs, &input_tvs}; - for (auto tv_filtered_group : tv_filtered_groups) { - for (auto tv : *tv_filtered_group) { - if (tv->isFusionInput() && tv->uses().empty()) { - continue; - } - if (grouped.count(tv) > 0) { - continue; - } - groups.emplace_back(std::vector{tv}); - grouped.emplace(tv); - // We only want to grab the inner-most dimension, because we don't want - // tensors with different inner-most dimension to be put in the same - // group. For example, if we have: - // T2[i1, i3*i2] = relu(view(transpose(T1[i1, i2, i3]))) - // then we don't want T1 and T2 to be in the same group. - // - // But we don't want to check contiguity. For example, if we have: - // T1[i1, i2, i3] (contiguous) + T2[i1, i2, i3] (discontiguous) - // Then we still want to T1 and T2 to be grouped together. - auto group = - scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false); - if (group.empty()) { - // In case that the inner most dim of tv is not found (for example, tv - // is a fusion input with only reductions), we just return a null - // result which will tell the scheduler to reject the fusion - return {}; - } - for (auto member_tv : group) { - if (grouped.count(member_tv) == 0) { - grouped.emplace(member_tv); - groups.back().emplace_back(member_tv); - } else if (member_tv != tv) { - // Ambiguous grouping. This should only happen at `canSchedule`, so - // we just return a null result which will tell the scheduler to - // reject the fusion - return {}; - } - } - } - } - std::stable_sort( - groups.begin(), - groups.end(), - [](const std::vector& v1, - const std::vector& v2) { - return v1.size() > v2.size(); - }); - return groups; - } - - // In the transpose scheculing, unlike the pointwise scheduling, the - // permissive map is required to find reference tensors. 
See also PR - // #661 - IterDomain* getMappedInputConcreteID( - const std::unordered_set& in_concrete_ids, - IterDomain* out_id) const override { - auto in_concrete_id_iter = std::find_if( - in_concrete_ids.begin(), - in_concrete_ids.end(), - [&](IterDomain* in_concrete_id) { - return ca_map_.areMapped( - in_concrete_id, out_id, IdMappingMode::PERMISSIVE); - }); - if (in_concrete_id_iter != in_concrete_ids.end()) { - return *in_concrete_id_iter; - } else { - return nullptr; - } - } -}; - // Note: [Supporting small transpose dimensions] // We prefer to make tiles of size 32x32 if there are enough elements to achieve // good occupancy, otherwise, we will use tile size 8x8. In both cases, it is @@ -559,13 +335,17 @@ HeuristicDataCacheEntry getDomainMap( Fusion* fusion) { auto domain_map_entry = HeuristicDataCacheEntry( - data_cache, - [fusion]() { return std::make_unique(fusion); }); + data_cache, [fusion]() { + return std::make_unique( + fusion); + }); return domain_map_entry; } HeuristicDataCacheEntry -getInputsOutputsGroups(HeuristicDataCache* data_cache, DomainMap& domain_map) { +getInputsOutputsGroups( + HeuristicDataCache* data_cache, + scheduler_tools::TransposeDomainMap& domain_map) { auto grouped_inputs_outputs_entry = HeuristicDataCacheEntry< HeuristicCompileTime::InputsOutputsInnerDimGroups>( data_cache, [&domain_map]() { @@ -584,7 +364,7 @@ getInputsOutputsGroups(HeuristicDataCache* data_cache, DomainMap& domain_map) { HeuristicDataCacheEntry getReferenceTensors( HeuristicDataCache* data_cache, - DomainMap& domain_map, + scheduler_tools::TransposeDomainMap& domain_map, std::vector>& grouped_inputs_outputs) { auto reference_tensors_entry = HeuristicDataCacheEntry( @@ -609,7 +389,7 @@ std::pair, int64_t> getShapeInReference( HeuristicDataCache* data_cache, SchedulerRuntimeInfo& runtime_info, TensorView* reference, - DomainMap& domain_map) { + scheduler_tools::TransposeDomainMap& domain_map) { auto ref_logical = reference->getLogicalDomain(); std::vector shape_in_ref; shape_in_ref.reserve(reference->nDims()); @@ -635,7 +415,7 @@ getInnerMostDimInfoInReference( HeuristicDataCache* data_cache, const std::vector& group_references, TensorView* global_reference, - DomainMap& domain_map) { + scheduler_tools::TransposeDomainMap& domain_map) { auto innermost_info_entry = HeuristicDataCacheEntry( data_cache, [&]() { @@ -659,7 +439,8 @@ std::string getTransposeRuntimeRejectReason( HeuristicDataCache* data_cache, SchedulerRuntimeInfo& runtime_info) { auto domain_map_entry = getDomainMap(data_cache, fusion); - auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto& domain_map = dynamic_cast( + domain_map_entry.get()); auto grouped_inputs_outputs_entry = getInputsOutputsGroups(data_cache, domain_map); auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); @@ -789,7 +570,7 @@ std::string getTransposeRuntimeRejectReason( } // namespace bool hasAtLeastTwoValidGroups(Fusion* fusion) { - return DomainMap::hasAtLeastTwoValidGroups(fusion); + return scheduler_tools::TransposeDomainMap::hasAtLeastTwoValidGroups(fusion); } std::unique_ptr getTransposeHeuristics( @@ -802,7 +583,8 @@ std::unique_ptr getTransposeHeuristics( const auto index_type = runtime_info.getIndexType(); auto domain_map_entry = getDomainMap(data_cache, fusion); - auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto& domain_map = dynamic_cast( + domain_map_entry.get()); auto grouped_inputs_outputs_entry = getInputsOutputsGroups(data_cache, domain_map); auto grouped_inputs_outputs = 
grouped_inputs_outputs_entry.get(); @@ -990,12 +772,12 @@ std::unique_ptr getTransposeHeuristics( << "max_io_dtype_size: " << max_io_dtype_size << "\n" << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) << "\n" - << "reference1: " << reference1 << "\n" + << "reference1: " << reference1->toString() << "\n" << "inner_most_id1 position: " << inner_most_pos1_in_ref1 << " (in reference 1)\n" << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) << "\n" - << "reference2: " << reference2 << "\n" + << "reference2: " << reference2->toString() << "\n" << "inner_most_id2 position: " << inner_most_pos2_in_ref1 << " (in reference 1)" << std::endl; if (hasSmallTransposeDimensions(tparams)) { @@ -1045,11 +827,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { int64_t max_dims = 0; for (auto inp : input_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + max_dims = std::max(scheduler_utils::nLogicalDims(inp), max_dims); } for (auto out : output_tvs) { - max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + max_dims = std::max(scheduler_utils::nLogicalDims(out), max_dims); } // If everything is zero dim tensors, just return. @@ -1057,7 +839,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { return; } - DomainMap domain_map(fusion); + scheduler_tools::TransposeDomainMap domain_map(fusion); auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); NVF_ERROR(grouped_inputs_outputs.size() >= 2); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 667ea7f40f2..cd22d935a52 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1198,9 +1198,9 @@ std::vector cacheInputs(Fusion* fusion, bool unroll) { for (auto tv : in_tvs) { if (tv->uses().empty() || ir_utils::isTorchGatherLookupTv(tv) || ir_utils::isIndexSelectLookupTv(tv) || - ir_utils::isTvUsedByOpsOfType(tv)) { - // Right now, tensors that are input to the slice, select, and pad ops - // can't be cached as they must be in global memory. + ir_utils::isTvUsedByOpsOfType(tv)) { + // Right now, tensors that are input to the select, gather and + // index_select ops can't be cached as they must be in global memory. continue; } @@ -1214,7 +1214,7 @@ std::vector cacheInputs(Fusion* fusion, bool unroll) { // caching load instructions. std::vector cached_uses; for (auto use : tv->uses()) { - if (!use->isA()) { + if (!use->isOneOf()) { cached_uses.push_back(use); } } @@ -1577,14 +1577,6 @@ std::vector getInputsOutputsWithInnerDim( continue; } - // Slice op is explicitly not enabled for vectorized load. 
- if (std::all_of( - input_tv->uses().begin(), - input_tv->uses().end(), - [](Expr* e) -> bool { return e->isA(); })) { - continue; - } - if (hasInnerDim(input_tv, vectorizable_dims, vectorize_pass)) { vectorizable_tensors.push_back(input_tv); } @@ -2661,6 +2653,19 @@ void moveNonConcretizedBroadcastInnermost( } } +int64_t reorderDevicesToOuter(TensorView* tv) { + int64_t reorder_pos = 0; + std::unordered_map old2new; + for (const auto i : c10::irange(tv->getLoopDomain().size())) { + if (tv->axis((int64_t)i)->isDeviceDim()) { + old2new.emplace((int64_t)i, reorder_pos); + ++reorder_pos; + } + } + tv->reorder(old2new); + return (int64_t)old2new.size(); +} + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 77317cde31b..62a359816d2 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -108,12 +108,12 @@ inline int64_t safeDiv(const int64_t x, const int64_t y) { // `to_update` to the positions in the splitted tensor. Splitting one dimension // multiple times is supported, and if this is the case, then the order of // `to_split` matters. All given dimensions are numbers before any split. -NVF_API void splitDims( +void splitDims( TensorView* tv, std::vector> to_split, // (dim, size) std::vector& to_update); -NVF_API inline void splitDims( +inline void splitDims( TensorView* tv, std::vector> to_split) { // (dim, size) std::vector unused; @@ -126,7 +126,7 @@ NVF_API inline void splitDims( // merge. // NOTE: merged is done as the entries in the order of `to_merge`, assuming an // order from inner to outer -NVF_API std::optional mergeDims( +std::optional mergeDims( TensorView* tv, std::vector to_merge, std::vector& to_update); @@ -153,7 +153,7 @@ int64_t mergeNonReduction(TensorView* tv); // DAG. Empty `selected_tvs` means selecting all tensors in the fusion of // `reference_tv`. `selected_parallel_types` are the selected parallel types. // Empty `selected_parallel_types` means selecting all parallel types. -NVF_API void parallelizeAllLike( +void parallelizeAllLike( TensorView* reference_tv, int64_t pos = -1, std::vector selected_tvs = {}, @@ -237,7 +237,7 @@ struct PersistentBufferInfo { // return inputs as being marked persistent if they follow this pattern. It is // important to note however inputs don't strictly have to be persistent as they // can simply be read multiple times from GMEM in the same kernel. -NVF_API PersistentBufferInfo persistentBuffers(Fusion* fusion); +PersistentBufferInfo persistentBuffers(Fusion* fusion); // A persistent tv can be projected to its producers when all the producers are // persistent tvs and there is no reduction op. @@ -304,7 +304,7 @@ struct PersistentBufferSizeReturn { // persistently, only based on buffers that must be persistent, and based on the // maximum of all minimum size requirement. i.e. if must be persistent, only // hold persistent dimension. -NVF_API PersistentBufferSizeReturn persistentBufferSize( +PersistentBufferSizeReturn persistentBufferSize( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, const PersistentBufferInfo& persistent_buffers, @@ -321,7 +321,7 @@ std::pair canonicalDimReduction( // Return a list of tensor views that are outputs of reduction operations, // excluding resharding reduce expressions. 
If multiple outputs of an expression // are found, only include one in the list -NVF_API std::vector getReductionTvs(Fusion* fusion); +std::vector getReductionTvs(Fusion* fusion); // Returns a list of TensorViews that are the consumer tv for a view operation. std::vector getViewTVs(Fusion* fusion); @@ -330,15 +330,15 @@ std::vector getViewTVs(Fusion* fusion); std::vector getTVsWithNonReductionRFactor(Fusion* fusion); // Reset inputs and outputs to global memory, everything else to local. -NVF_API void clearMemorySpace(Fusion* fusion); +void clearMemorySpace(Fusion* fusion); // Returns cached after tensors of the fusion inputs if unrolled. Otherwise // return empty vector. -NVF_API std::vector cacheInputs(Fusion* fusion, bool unroll); +std::vector cacheInputs(Fusion* fusion, bool unroll); // Returns the pairs of for // all outputs. -NVF_API std::vector> cacheAndForkOutputs( +std::vector> cacheAndForkOutputs( Fusion* fusion, bool unroll); @@ -473,7 +473,7 @@ struct BroadcastMultipleInformation { // // logical_reorder_map is provided to assume reference_tv will be reordered per // the map -NVF_API BroadcastMultipleInformation getBroadcastMultiples( +BroadcastMultipleInformation getBroadcastMultiples( TensorView* reference_tv, DataType index_type, const std::unordered_map& logical_reorder_map = {}); @@ -542,7 +542,7 @@ struct BoundedDirectionalTransformPropagator { //! Replay transforms from tensorview `from` //! to the tensorviews that are consumers //! of boundary tensorviews in `to` and producers of `from`. - NVF_API static void backward( + static void backward( TensorView* from, int64_t pos, std::vector to, @@ -601,13 +601,13 @@ struct BoundedDirectionalTransformPropagator { // If IterDomains are disjoint in the returned set, then they are considered // "separable". // Warning: This pass generates the IdGraphs, not intended for use at runtime. -NVF_API DisjointSets disjointLogicalSets(Fusion* fusion); +DisjointSets disjointLogicalSets(Fusion* fusion); // Makes sure that there are no group id's left of pos that match right of pos. // e.g. // [1, 0, 0] pos 2 would return false // [1, 0, 0] pos 1 would return true -NVF_API bool breakIsDisjoint(std::vector group_ids, int64_t pos); +bool breakIsDisjoint(std::vector group_ids, int64_t pos); // Generates an old to new map to reorder tv's domain as the logical order. // Priority is given to inner most dimensions for example: @@ -615,8 +615,7 @@ NVF_API bool breakIsDisjoint(std::vector group_ids, int64_t pos); // domain [i0*i2, i1] // will produce the map {{0, 1}, {1, 0}} // This is somewhat similar to orderTiledConcreteIdAsRoot -NVF_API std::unordered_map domainReorderAsLogicalMap( - TensorView* tv); +std::unordered_map domainReorderAsLogicalMap(TensorView* tv); // Generates an old to new map to reorder tv's domain as the logical order. // This only handles the simple case where allocation is a permutation of @@ -629,7 +628,7 @@ std::unordered_map maybeLogicalReorderAsAllocationMap( void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map); //! Check if tv is an output of a fastest-dim reduction -NVF_API bool isFastestDimReduction(TensorView* tv); +bool isFastestDimReduction(TensorView* tv); // A wrapper for Fusion::rotateLoop that provide more consistent interace inline void rotateLoop( @@ -670,7 +669,7 @@ inline void rotateLoop( //! tv1, but the data dependency for the resize op is still satisfied //! by having a copy of tv1, i.e., tv4. Note that the other op using //! tv1 still uses tv1. 
-NVF_API void prepareForMemoryTypePromotion(Fusion* fusion); +void prepareForMemoryTypePromotion(Fusion* fusion); //! If a consumer tensor induces a data dependency between threads, //! move its producer to a shared memory that is sufficient to satisfy @@ -678,13 +677,13 @@ NVF_API void prepareForMemoryTypePromotion(Fusion* fusion); //! with blockIdx, the producer memory type will be changed to //! Global. A proper RAW sync will be automatically inserted when the //! fusion is lowered. -NVF_API void promoteProducerMemoryTypes( +void promoteProducerMemoryTypes( Fusion* fusion, const std::vector& input_caches); //! Get all tensors that are connected to from_tvs without going through //! any tvs in the cutoff_tv_set. -NVF_API std::unordered_set getAllTvsFrom( +std::unordered_set getAllTvsFrom( const std::vector& from_tvs, const std::unordered_set& cutoff_tv_set); @@ -729,5 +728,22 @@ void moveNonConcretizedBroadcastInnermost( Fusion* fusion, const std::unordered_set& ignored_tvs = {}); +// Reorder DID parallelized axes to outermost positions. Returns +// the position of the outermost non-DID axis. +int64_t reorderDevicesToOuter(TensorView* tv); + +// Returns number of non-reduction/non-broadcas/non-device dims in logical +// domain +inline int64_t nLogicalDims(const TensorView* tv) { + auto logical_dom = tv->getLogicalDomain(); + int64_t tv_n_dims = 0; + for (auto dim : logical_dom) { + if (!dim->isReduction() && !dim->isBroadcast() && !dim->isDeviceDim()) { + tv_n_dims++; + } + } + return tv_n_dims; +} + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 3dcad2b497d..ab1756df1c8 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include @@ -55,30 +56,6 @@ Val* ContiguousInnerDimensionsMapper::isFullyProjected(IterDomain* id) { getProjectedExtent(id), commonOrConstExtent(ca_map_, id)); } -void ContiguousInnerDimensionsMapper::initializeResizeInfo(Fusion* fusion) { - auto exprs = fusion->exprs(); - for (auto* pad_op : ir_utils::filterByType(exprs)) { - if (!pad_op->out()->isA()) { - continue; - } - - auto* out_tv = pad_op->out()->as(); - - auto consumer_exprs = StmtSort::getExprsBetween( - {out_tv->getMaybeRootDomain().begin(), - out_tv->getMaybeRootDomain().end()}, - {out_tv->getLogicalDomain().begin(), out_tv->getLogicalDomain().end()}); - - // NOTE: if we can assume that PadOp is always on inputs, then we can skip - // to innermost resize instead. - auto resize_ops = ir_utils::filterByType(consumer_exprs); - std::copy( - resize_ops.begin(), - resize_ops.end(), - std::inserter(resize_in_pad_, resize_in_pad_.end())); - } -} - ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper( TensorView* reference, const std::vector& ids, @@ -92,8 +69,6 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper( divisible_splits_(divisible_splits) { FusionGuard fg(reference->fusion()); - initializeResizeInfo(reference->fusion()); - // Exclude reduction IDs if the reference is a fusion input as they // don't manifest at all in the fusion. This simplifies the // analysis in getContigMergeOfInnerSize, which only looks at @@ -400,44 +375,51 @@ std::vector ContiguousInnerDimensionsMapper::projectId( if (it == frontier.end()) { return; } - auto pos = std::distance(frontier.begin(), it); - if (resize_in_pad_.count(resize_op) != 0) { - // resize created by PadOp. - // project resize op to frontier. 
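// Usage sketch for the nLogicalDims helper added to scheduler/utils.h above
// (tensor shape and mesh size are made up for illustration): device-parallel,
// broadcast and reduction dimensions are not counted, so the scheduler sees
// the dimensionality of the per-device problem.
Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv = makeContigConcreteTensor({4, 8, 16}); // [i0, i1, i2]
fusion.addInput(tv);
tv->setDeviceMesh(DeviceMesh::createForNumDevices(4));
tv->axis(0)->parallelize(ParallelType::DIDx); // i0 becomes a device dim
// Only i1 and i2 are counted.
NVF_ERROR(scheduler_utils::nLogicalDims(tv) == 2);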
- frontier[pos] = id_to; - // clear left of resize, since those are no long contiguous. - frontier.erase(frontier.begin(), it); + // project resize op to frontier. + frontier[pos] = id_to; + // clear left of resize, since those are no longer contiguous. + frontier.erase(frontier.begin(), it); - if (recording_) { - // TODO: support negative resize extent. - // - // Limit current support to only positive resize extent for now. So we - // only consider the pad_extent, which becomes the real buffer on - // output. Hence we do GCD among padded extent as well as extent of the - // id_from. Note since we are taking the GCD here, I don't think using - // id_from or id_to makes a difference. - auto consumer_factor = getProjectedExtent(id_from); - auto comp = [](Val* factor, Val* extent) { - return SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::eqExpr( - extent, extent->container()->zeroVal()), - factor, - // for extent < 0, we'll take max(1, extent). Because of the gcd, - // This is effectively excluding the resize id from vectorization. - SimplifyingIrBuilder::gcdExpr( - factor, - SimplifyingIrBuilder::maxExpr( - extent->container()->oneVal(), extent))); - }; - consumer_factor = comp(consumer_factor, resize_op->leftExpand()); - consumer_factor = comp(consumer_factor, resize_op->rightExpand()); - addProjectedExtent(id_to, consumer_factor); - } - } else { - // unsupproted resize. - frontier.erase(frontier.begin(), it + 1); + if (recording_) { + // we need to check slice offset at this point. + auto projected_extent = getProjectedExtent(id_from); + + // projected_extent == 0: return the projected_extent as-is + // resize_extent == 0 : no resizing, return the projected_extent as-is + // resize_extent != 0 : slicing/padding, return gcd(projected_extent, + // abs(resize_extent)) This is a conservative analysis of the offset for + // data access. A better approach needs to consider the actual start + // pointer address and handle it in alignment analysis in runtime info. We + // also need to consider multiple resizes stacked together and how they + // could interact with each other. Translating this to code: if + // (resize_extent == 0 || projected_extent == 0) { + // return projected_extent; + // } else { + // gcd(projected_extent, abs(resize_extent)); + // } + auto comp = [](Val* projected_extent, Val* resize_extent) { + return SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::logicalOrExpr( + SimplifyingIrBuilder::eqExpr( + resize_extent, resize_extent->container()->zeroVal()), + SimplifyingIrBuilder::eqExpr( + projected_extent, + projected_extent->container()->zeroVal())), + projected_extent, + SimplifyingIrBuilder::gcdExpr( + projected_extent, IrBuilder::absExpr(resize_extent))); + }; + projected_extent = comp(projected_extent, resize_op->leftExpand()); + projected_extent = comp(projected_extent, resize_op->rightExpand()); + + // cap extent by the destination, this is useful when the id_to is resized + // to zero, where projected_extent shouldn't go beyond the total extent.
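// A scalar walk-through of the rule encoded symbolically in this hunk (plain
// integers instead of SimplifyingIrBuilder Vals; uses std::gcd from
// <numeric>). The helper name and the concrete numbers are made up for
// illustration.
int64_t projectedExtentAfterResize(
    int64_t projected_extent, // contiguous extent projected so far
    int64_t left_expand,      // Resize left extent (negative when slicing)
    int64_t right_expand,     // Resize right extent (negative when slicing)
    int64_t id_to_extent) {   // extent of the resized output IterDomain
  auto fold = [](int64_t pe, int64_t resize_extent) {
    return (resize_extent == 0 || pe == 0)
        ? pe
        : std::gcd(pe, std::abs(resize_extent));
  };
  int64_t pe = fold(projected_extent, left_expand);
  pe = fold(pe, right_expand);
  return std::min(pe, id_to_extent); // cap by the destination extent
}
// Example: slicing [0, 30) out of 32 elements has left_expand = 0 and
// right_expand = -2, so a projected extent of 8 shrinks to gcd(8, 2) = 2,
// i.e. at most 2-wide vectorized access across that resize.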
+ projected_extent = + SimplifyingIrBuilder::minExpr(projected_extent, id_to->extent()); + + addProjectedExtent(id_to, projected_extent); } }; @@ -848,14 +830,14 @@ std::vector> getTvToContigInnerSizeMapsOf( TensorView* ref, const std::unordered_map& logical_reorder_map) { std::vector> mappers; - auto root_dom = ref->getLogicalDomain(); + auto logical_dom = ref->getLogicalDomain(); if (!logical_reorder_map.empty()) { - root_dom = TensorDomain::orderedAs(root_dom, logical_reorder_map); + logical_dom = TensorDomain::orderedAs(logical_dom, logical_reorder_map); } - while (!root_dom.empty()) { - mappers.push_back(ContiguousInnerDimensionsMapper::map(ref, root_dom) + while (!logical_dom.empty()) { + mappers.push_back(ContiguousInnerDimensionsMapper::map(ref, logical_dom) .getTvToContigMergeOfInnerSizeMap()); - root_dom.erase(root_dom.begin()); + logical_dom.erase(logical_dom.begin()); } return mappers; } diff --git a/csrc/scheduler/vectorize_helper.h b/csrc/scheduler/vectorize_helper.h index d5c8e26c406..2ccc10fc52f 100644 --- a/csrc/scheduler/vectorize_helper.h +++ b/csrc/scheduler/vectorize_helper.h @@ -289,9 +289,6 @@ class NVF_API ContiguousInnerDimensionsMapper void propagateP2C(TensorView* from, TensorView* to) final; void propagateSibling(TensorView* from, TensorView* to) final; - // traverse fusion to mark the origin of Resize - void initializeResizeInfo(Fusion* fusion); - // Initialized to false, series of compute... calls will be performed to find // the spanning tree. Then propagate... calls will call the compute... calls. // recording_ starts as false, and stays that way during the first series of @@ -311,9 +308,6 @@ class NVF_API ContiguousInnerDimensionsMapper tv_infos_; std::unordered_map projected_extent_; - - //! stores all Resize* op that's added from PadOp* - std::unordered_set resize_in_pad_; }; // logical_reorder_map is provided to assume reference_tv will be reordered per diff --git a/csrc/utils.cpp b/csrc/utils.cpp index 4dcf877c1e9..094e86762b6 100644 --- a/csrc/utils.cpp +++ b/csrc/utils.cpp @@ -50,9 +50,10 @@ std::string debug_str(const c10::IValue& val) { std::string debug_str(const at::Tensor& tensor) { std::stringstream ss; ss << "Tensor:"; - ss << " device: " << tensor.device(); + ss << " shape: " << tensor.sizes(); ss << ", dtype: " << tensor.dtype(); - ss << ", shape: " << tensor.sizes(); + ss << ", device: " << tensor.device(); + ss << ", pointer: " << reinterpret_cast(tensor.data_ptr()); if (!tensor.is_contiguous()) { ss << ", strides: " << tensor.strides(); diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index 132f8bb6a8d..970338c209c 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -216,4 +216,44 @@ class ValGraphBFS : public BFS< } }; +class ValGraphPermissiveBFS : public BFSWithPermissiveDependence< + ExprGroup, + ValGroup, + ValGraphDefinitions, + ValGraphUses, + ValGraphInputs, + ValGraphOutputs> { + public: + ValGraphPermissiveBFS( + const ValGraph& graph, + std::vector from_groups, + std::vector to_groups, + bool require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined) + : BFSWithPermissiveDependence( + ValGraphDefinitions(graph), + ValGraphUses(graph), + ValGraphInputs(graph), + ValGraphOutputs(graph), + std::move(from_groups), + std::move(to_groups), + require_all_to_visited, + allowed_direction) {} + + // Just a shortcut to the generic getExprsBetween + static std::pair getExprGroupsBetween( + const ValGraph& graph, + const ValGroups& from, + const ValGroups& to, + bool 
require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined) { + return getExprsBetween( + from.vector(), + to.vector(), + require_all_to_visited, + allowed_direction, + graph); + } +}; + } // namespace nvfuser diff --git a/nvfuser/pytorch_utils.py b/nvfuser/pytorch_utils.py index 7ad7c0c3e26..8ba45490377 100644 --- a/nvfuser/pytorch_utils.py +++ b/nvfuser/pytorch_utils.py @@ -43,23 +43,6 @@ def torch_dtype_to_nvfuser_dtype(dtype: Union[torch.dtype, NumberTypeType]): return _torch_dtype_to_nvfuser_dtype_map[dtype] -def patch_codegen_so(): - """ - Replace libnvfuser_codegen.so installed along with torch - """ - import torch - import shutil - import os - - dst_dir = os.path.join(os.path.dirname(torch.__file__), "lib") - src_dir = os.path.join(os.path.dirname(__file__), "lib") - - shutil.copyfile( - os.path.join(src_dir, "libnvfuser_codegen.so"), - os.path.join(dst_dir, "libnvfuser_codegen.so"), - ) - - def get_device_properties() -> Tuple[int, float]: """ Computes device properties using ctypes and cuda. diff --git a/runtime/cluster.cu b/runtime/cluster.cu new file mode 100644 index 00000000000..ca0e7f91b31 --- /dev/null +++ b/runtime/cluster.cu @@ -0,0 +1,89 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +// The optional .relaxed qualifier on barrier.cluster.arrive specifies that +// there are no memory ordering and visibility guarantees provided for the +// memory accesses performed prior to barrier.cluster.arrive. +void clusterArriveRelaxed() { + asm volatile("barrier.cluster.arrive.relaxed.aligned;" : :); +} + +// A thread arrives at barrier but it does not have to wait for threads in other +// participating warps. +void clusterArrive() { + asm volatile("barrier.cluster.arrive.aligned;" : :); +} + +// A thread waits for all non-exited threads of the cluster to perform +// cluster_arrive. +void clusterWait() { + asm volatile("barrier.cluster.wait.aligned;" : :); +} + +// Synchronize threads in cluster +void clusterSync() { + cluster_arrive(); + cluster_wait(); +} + +// Returns the dim3 grid size in terms of number of clusters. +dim3 clusterGridDims() { + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%nclusterid.x;" : "=r"(x) :); + asm volatile("mov.u32 %0, %%nclusterid.y;" : "=r"(y) :); + asm volatile("mov.u32 %0, %%nclusterid.z;" : "=r"(z) :); + return {x, y, z}; +} + +// Returns the dim3 cluster rank in the grid. +dim3 clusterIdInGrid() { + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%clusterid.x;" : "=r"(x) :); + asm volatile("mov.u32 %0, %%clusterid.y;" : "=r"(y) :); + asm volatile("mov.u32 %0, %%clusterid.z;" : "=r"(z) :); + return {x, y, z}; +} + +// Returns the relative dim3 block rank local to the cluster. +dim3 blockIdInCluster() { + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_ctaid.x;" : "=r"(x) :); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;" : "=r"(y) :); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;" : "=r"(z) :); + return {x, y, z}; +} + +// Returns the dim3 cluster shape. +dim3 clusterShape() { + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;" : "=r"(x) :); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;" : "=r"(y) :); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;" : "=r"(z) :); + return {x, y, z}; +} + +// Get 1D ctaid in a cluster. 
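// Usage sketch for the cluster-ID helpers above (assumes the same compilation
// setup as the surrounding runtime functions and sm_90+, and the standard
// contiguous tiling of CTAs into clusters). This helper is illustrative and
// not part of the patch.
dim3 globalBlockIdFromCluster() {
  dim3 cluster_id = clusterIdInGrid(); // which cluster within the grid
  dim3 block_id = blockIdInCluster();  // which CTA within that cluster
  dim3 shape = clusterShape();         // CTAs per cluster in each dimension
  return dim3(
      cluster_id.x * shape.x + block_id.x,
      cluster_id.y * shape.y + block_id.y,
      cluster_id.z * shape.z + block_id.z);
}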
+uint32_t blockRankInCluster() { + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;" : "=r"(rank) :); + return rank; +} + +// Set the destination block-ID in cluster for a given SMEM Address +uint32_t mapSharedRank(uint32_t smemAddr, uint32_t rank) { + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +} + +#endif // Arch 90 diff --git a/tests/cpp/multidevice_transformer.cpp b/tests/cpp/multidevice_transformer.cpp new file mode 100644 index 00000000000..fe552b6606a --- /dev/null +++ b/tests/cpp/multidevice_transformer.cpp @@ -0,0 +1,660 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include + +namespace nvfuser { +namespace { +// TODO: These linearBackwards helper functions can be merged once +// we do not have logically split rfactor domain. +struct LinearBackwardsResult { + TensorView* grad_x; + TensorView* grad_w; + TensorView* grad_b; +}; + +// x format: [i0, i1] dtype +// weight format: [DID(D), i2/D, i1] dtype +// grad format: [DID(D) i0, i2/D] float or dtype +// outputs: grad_x [i0, i1] dtype +// grad_w [DID i2/D, i1] dtype +// grad_b [DID i2/2] dtype +LinearBackwardsResult linearBackwards( + TensorView* x, + TensorView* w, + TensorView* grad) { + DataType dtype = w->dtype(); + TensorView* grad_f = maybeCastOp(DataType::Float, grad); + TensorView* grad_q = maybeCastOp(dtype, grad); + TensorView* grad_x_partials = matmul(grad_q, w); + TensorView* grad_x = sum(grad_x_partials, {0}); // allreduce + TensorView* grad_q_t = transpose(grad_q, 1, 2); + TensorView* grad_w = matmul(grad_q_t, x); + TensorView* grad_b = sum(grad_f, {1}); + grad_b = castOp(dtype, grad_b); + + return {grad_x, grad_w, grad_b}; +} + +// x format: [DID, i0, i1/D] dtype +// weight format: [DID, i2, i1/D] dtype +// grad format: [i0, i2] float +// outputs: grad_x [DID i0, i1/D] dtype +// grad_w [DID, i2, i1/D] dtype +// grad_b [i2] dtype +LinearBackwardsResult shardedLinearBackwards( + TensorView* x, + TensorView* w, + TensorView* grad) { + DataType dtype = w->dtype(); + TensorView* grad_q = castOp(dtype, grad); + TensorView* grad_x = matmul(grad_q, w); + TensorView* grad_t = transpose(grad_q, 0, 1); + TensorView* grad_w = matmul(grad_t, x); + TensorView* grad_b = sum(grad, {0}); + grad_b = castOp(dtype, grad_b); + + return {grad_x, grad_w, grad_b}; +} + +// Forward layer_norm with cached mean_bcast and invstd tensors to avoid +// recomputing Welford. For use in backwards pass. 
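// In equation form, the recomputation the helper below performs (standard
// layer norm applied with the statistics saved by the forward pass):
//   x_hat = (x - mean) * invstd
//   y     = x_hat * weight + bias
// Only this element-wise part is re-executed in the backward fusion; the
// Welford reduction that produced mean and invstd is not repeated.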
+TensorView* layerNormWithCachedStats( + TensorView* x, + TensorView* mean_bcast, + TensorView* invstd, + const std::vector& norm_shape, + TensorView* weight, + TensorView* bias) { + const int64_t kNumberOfDims = + (int64_t)TensorDomain::noReductions(x->getLogicalDomain()).size(); + const int64_t kOuterNumDims = kNumberOfDims - norm_shape.size(); + std::vector outer_broadcast_mask(kNumberOfDims, false); + for (const auto idx : c10::irange(kOuterNumDims)) { + outer_broadcast_mask[idx] = true; + } + + auto x_sub_mean = sub(x, mean_bcast); + auto y = mul(x_sub_mean, invstd); + + auto weight_bcast = broadcast(weight, outer_broadcast_mask); + y = mul(y, weight_bcast); + auto bias_bcast = broadcast(bias, outer_broadcast_mask); + return add(y, bias_bcast); +} +} // namespace + +MlpResult DistributedTransformer::mlp( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel) { + const DataType dtype = w0->dtype(); + + if (sequence_parallel) { + // Input arrives sharded and must be allgathered back + x->setDeviceMesh(mesh); + x->axis(0)->parallelize(ParallelType::DIDx); + x = set(x); // allgather + x->axis(0)->parallelize(ParallelType::Serial); + // Reshape back to 2D. This is uncessary except to keep + // the shapes of linear0 the same for TP and TP+SP. + x = reshape(x, {D, B * S / D, E}, {B * S, E}); + } + // Linear 0 + TensorView* linear0 = linear(x, w0, b0); + // GeLU + TensorView* gelu = tanh_gelu(castOp(DataType::Float, linear0)); + gelu = castOp(dtype, gelu); + // Linear 1 + TensorView* local_matmul1 = matmul(gelu, transpose(w1, 1, 2)); + if (sequence_parallel) { + // Remove after https://github.com/NVIDIA/Fuser/issues/2563 + // Reshape to explicitly pull the sharded axis into the logical domain + local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); + } + TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter + std::vector bcast_mask(matmul1->nDims() - 1, true); + bcast_mask[matmul1->nDims() - 2] = false; + TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); + // Dropout + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + TensorView* dropout_result = dropout(linear1, prob, scale).output; + + // Tensor parallel shardings + for (auto* tv : {w0, b0, w1}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + for (auto* tv : {x, b1}) { + tv->setDeviceMesh(mesh); + } + + // Sequence parallel shardings + if (sequence_parallel) { + matmul1->setDeviceMesh(mesh); + matmul1->axis(1)->parallelize(ParallelType::DIDx); + } + + return {linear0, gelu, matmul1, linear1, dropout_result}; +} + +MhaResult DistributedTransformer::mha( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel) { + auto dtype = w0->dtype(); + + if (sequence_parallel) { + // Input arrives sharded and must be allgathered back + x->setDeviceMesh(mesh); + x->axis(0)->parallelize(ParallelType::DIDx); + x = set(x); // allgather + x->axis(0)->parallelize(ParallelType::Serial); + // Reshape is uncessary, it is here to keep shapes with TP and TP+SP the + // same for validation. 
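// The allgather idiom used a few lines above, shown in isolation (a sketch;
// D, B, S, E, dtype and mesh are taken from the surrounding context, and the
// exact annotations are illustrative): copying a DIDx-sharded tensor into a
// serial axis is what the multidevice lowering turns into an allgather.
TensorView* sharded = makeContigConcreteTensor({D, B * S / D, E}, dtype);
sharded->setDeviceMesh(mesh);
sharded->axis(0)->parallelize(ParallelType::DIDx); // [DIDx{D}, B*S/D, E]
TensorView* replicated = set(sharded);             // lowered to an allgather
replicated->axis(0)->parallelize(ParallelType::Serial);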
+ x = reshape(x, {D, B * S / D, E}, {B * S, E}); + } + + TensorView* linear0 = linear(x, w0, b0); + // Forming the q,k,v vectors: + TensorView* qkv_cat = + reshape(linear0, {D, B * S, 3 * E / D}, {D, B, S, 3 * E / D}); + std::vector qkv = chunk(qkv_cat, 3, -1); + for (auto i : c10::irange(3)) { + qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); + qkv[i] = transpose(qkv[i], 2, 3); + } + // SDPA + SdpfaFwdResult sdpa = sdpfa_fwd( + qkv[0], + qkv[1], + qkv[2], + IrBuilder::create(kSdpaProb), + IrBuilder::create(true), + IrBuilder::create(kSdpaScale)); + TensorView* sdpa_output = sdpa.output; + // Linear 1 + TensorView* sdpa_transpose = transpose(sdpa_output, 2, 3); + TensorView* sdpa_reshape = + reshape(sdpa_transpose, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* local_matmul1 = matmul(sdpa_reshape, transpose(w1, 1, 2)); + if (sequence_parallel) { + // Remove after https://github.com/NVIDIA/Fuser/issues/2563 + // Reshape to explicitly pull the sharded axis into the logical domain + local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); + } + TensorView* matmul1 = sum(local_matmul1, {0}); // allreduce + std::vector bcast_mask(matmul1->nDims() - 1, true); + bcast_mask[matmul1->nDims() - 2] = false; + TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); + // Dropout + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + TensorView* dropout_result = dropout(linear1, prob, scale).output; + + // Tensor parallel shardings + for (auto tv : {x, b1}) { + tv->setDeviceMesh(mesh); + } + for (auto tv : {w0, b0, w1}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + // Sequence parallel sharding. + if (sequence_parallel) { + matmul1->setDeviceMesh(mesh); + matmul1->axis(1)->parallelize(ParallelType::DIDx); + } + + return {linear0, sdpa_output, matmul1, linear1, dropout_result}; +} + +std::vector DistributedTransformer::mlp_backwards( + TensorView* grad, + TensorView* x, + TensorView* mask, + TensorView* w0, + TensorView* w1, + TensorView* linear0, + const DeviceMesh& mesh) { + DataType dtype = w0->dtype(); + + // Activation recomputation: Always recompute gelu + TensorView* gelu = castOp(dtype, tanh_gelu(castOp(DataType::Float, linear0))); + + // Backwards pass + const double kScale = 1.0 / (1.0 - kDropoutProb); + Val* dropout_scale = IrBuilder::create(kScale); + TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); + auto linear1_grads = shardedLinearBackwards(gelu, w1, dropout_grad); + TensorView* matmul1_grad_x_ = castOp(DataType::Float, linear1_grads.grad_x); + TensorView* gelu_grad = tanh_gelu_backward(matmul1_grad_x_, linear0); + auto linear0_grads = linearBackwards(x, w0, gelu_grad); + + // Manaul sharding annotations + for (auto tv : + {x, + grad, + mask, + dropout_grad, + linear1_grads.grad_b, + linear0_grads.grad_x}) { + tv->setDeviceMesh(mesh); + } + + for (auto tv : + {w0, + w1, + linear0, + linear1_grads.grad_x, + linear1_grads.grad_w, + gelu_grad, + linear0_grads.grad_w, + linear0_grads.grad_b}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + + std::vector outputs = { + dropout_grad, + linear1_grads.grad_w, + linear1_grads.grad_b, + gelu_grad, + linear0_grads.grad_w, + linear0_grads.grad_b, + linear0_grads.grad_x}; + return outputs; +} + +std::vector DistributedTransformer::mha_backwards( + TensorView* x, + TensorView* w0, + TensorView* w1, + TensorView* mask, + TensorView* 
sdpa_output, + TensorView* sdpa_log_sumexp, + TensorView* sdpa_seed, + TensorView* sdpa_offset, + TensorView* grad, + TensorView* linear0, + const DeviceMesh& mesh) { + DataType dtype = w0->dtype(); + // Reform qkv from linear0 output + TensorView* qkv_cat = reshape( + castOp(DataType::Float, linear0), + {D, B * S, 3 * E / D}, + {D, B, S, 3 * E / D}); + std::vector qkv = chunk(qkv_cat, 3, -1); + for (auto i : c10::irange(3)) { + qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); + qkv[i] = transpose(qkv[i], 2, 3); + qkv[i] = castOp(dtype, qkv[i]); + qkv[i]->setDeviceMesh(mesh); + qkv[i]->axis(0)->parallelize(ParallelType::DIDx); + } + + // dropout backwards + const double kScale = 1.0 / (1.0 - kDropoutProb); + auto dropout_scale = IrBuilder::create(kScale); + TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); + + // linear1 backwards + TensorView* sdpa_output_reshape = + transpose(sdpa_output, 2, 3); // D, B, S, H/D, E/H + sdpa_output_reshape = + reshape(sdpa_output_reshape, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + auto linear1_grads = + shardedLinearBackwards(sdpa_output_reshape, w1, dropout_grad); + + // SDPA backwards + TensorView* linear1_x_grad = + reshape(linear1_grads.grad_x, {D, B * S, E / D}, {D, B, S, H / D, E / H}); + linear1_x_grad = transpose(linear1_x_grad, 2, 3); // D, B, H/D, S, E/H + // Explicitly shard inputs before SDPA backward node + for (auto tv : {linear1_x_grad, sdpa_output, sdpa_log_sumexp}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + auto sdpa_grad = sdpfa_bwd( + linear1_x_grad, + qkv[0], + qkv[1], + qkv[2], + sdpa_output, + sdpa_log_sumexp, + /*dropout_p=*/IrBuilder::create(kSdpaProb), + /*is_causal=*/IrBuilder::create(true), + sdpa_seed, + sdpa_offset, + /*scale=*/IrBuilder::create(kSdpaScale)); + + TensorView* q_grad = transpose(sdpa_grad.grad_query, 2, 3); + q_grad = reshape(q_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* v_grad = transpose(sdpa_grad.grad_value, 2, 3); + v_grad = reshape(v_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* k_grad = transpose(sdpa_grad.grad_key, 2, 3); + k_grad = reshape(k_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* kqv_grad = cat({k_grad, q_grad, v_grad}, -1); + auto linear0_grads = linearBackwards(x, w0, kqv_grad); + + for (auto tv : + {x, + mask, + grad, + dropout_grad, + linear1_grads.grad_b, + linear0_grads.grad_x}) { + tv->setDeviceMesh(mesh); + } + for (auto tv : + {w0, + w1, + sdpa_output, + sdpa_log_sumexp, + linear0, + linear1_grads.grad_x, + linear1_grads.grad_w, + linear0_grads.grad_w, + linear0_grads.grad_b, + sdpa_grad.grad_query, + sdpa_grad.grad_key, + sdpa_grad.grad_value}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + return { + dropout_grad, + linear1_grads.grad_w, + linear1_grads.grad_b, + sdpa_grad.grad_query, + sdpa_grad.grad_key, + sdpa_grad.grad_value, + linear0_grads.grad_w, + linear0_grads.grad_b, + linear0_grads.grad_x}; +} + +std::unique_ptr DistributedTransformer::forward( + DataType dtype, + bool sequence_parallel) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto mesh = DeviceMesh::createForNumDevices(D); + + TensorView* x = sequence_parallel + ? 
makeContigConcreteTensor({D, B * S / D, E}, dtype) + : makeContigConcreteTensor({B * S, E}, dtype); + TensorView* ln0_w = makeContigTensor(1); + TensorView* ln0_b = makeContigTensor(1); + TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); + TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype); + TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); + TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); + TensorView* ln1_w = makeContigTensor(1); + TensorView* ln1_b = makeContigTensor(1); + TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype); + TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype); + TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype); + TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype); + + fusion->addInput(x); + fusion->addInput(ln0_w); + fusion->addInput(ln0_b); + fusion->addInput(mha_w0); + fusion->addInput(mha_b0); + fusion->addInput(mha_w1); + fusion->addInput(mha_b1); + fusion->addInput(ln1_w); + fusion->addInput(ln1_b); + fusion->addInput(mlp_w0); + fusion->addInput(mlp_b0); + fusion->addInput(mlp_w1); + fusion->addInput(mlp_b1); + + constexpr float kEps = 1e-5; + auto eps = IrBuilder::create(kEps); + std::vector norm_shape{E}; + + auto ln_input = castOp(DataType::Float, x); + auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); + auto mha_in = castOp(dtype, ln0.output); + auto mha_tvs = + mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, sequence_parallel); + auto resid0 = add(ln_input, mha_tvs.output); + auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); + auto mlp_in = castOp(dtype, ln1.output); + auto mlp_tvs = + mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, sequence_parallel); + auto resid1 = add(resid0, mlp_tvs.output); + resid1 = castOp(dtype, resid1); + + fusion->addOutput(ln0.output); + fusion->addOutput(mha_tvs.output); + fusion->addOutput(ln1.output); + fusion->addOutput(mlp_tvs.output); + fusion->addOutput(resid1); + + x->setDeviceMesh(mesh); + if (sequence_parallel) { + // Input arrives sharded + x->axis(0)->parallelize(ParallelType::DIDx); + // Propagate SP shardings from x through layernorms, dropouts, residual + // adds. Even though mha_in is part of the boundary set, residuals allow the + // shardings to propagate up the graph so we must cut off the propagation at + // the outputs of reduce scatters (mha and mlp matmul1) + shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x); + // Propagate TP sharding for MLP and MHA from sharded weights. We do not + // need to shard from mha_b0 or mlp_b0 because they are only consumed by + // their respective linear0 expression which is sharded from *_w0. + shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0); + shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1); + shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0); + shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1); + } else { + // TP only shardings + // Layernorm, residuals, are all replicated like x. shardBetween + // shards all tvs reachable from x, so the input and output tvs must + // be in the boundary set. 
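// Schematic of the shardBetween boundary behavior relied on here (names are
// illustrative). Given a chain
//   tv_a -> tv_b -> tv_c -> tv_d
// shardBetween({tv_a}, {tv_c}, ref) applies ref's device mesh and DID
// parallelization to the tensors reachable from tv_a up to the tv_c boundary
// and stops there, which is why tensors that must keep a different sharding
// (for example the reduce-scatter outputs in the sequence-parallel branch)
// have to appear in the boundary set.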
+ shardBetween({x}, {mha_in, mha_tvs.output, mlp_in, mlp_tvs.output}, x); + // TP sharded regions within mha and mlp + shardBetween({mha_in}, {mha_tvs.output}, mha_w0); + shardBetween({mlp_in}, {mlp_tvs.output}, mlp_w0); + } + + return std::make_unique(std::move(fusion)); +} + +std::unique_ptr DistributedTransformer::backward( + DataType dtype) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto mesh = DeviceMesh::createForNumDevices(D); + std::vector norm_shape{E}; + + TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); + TensorView* grad = makeContigTensor(2, dtype); + TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); + TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); + TensorView* mlp_w0 = makeContigTensor(3, dtype); + TensorView* mlp_w1 = makeContigTensor(3, dtype); + TensorView* mha_mask = makeContigTensor(2, DataType::Bool); + TensorView* mlp_mask = makeContigTensor(2, DataType::Bool); + TensorView* mha_sdpa_out = makeConcreteTensor({D, B, H / D, S, E / H}, dtype); + TensorView* mha_sdpa_log_sumexp = + makeContigConcreteTensor({D, B, H / D, S}, DataType::Float); + TensorView* mha_sdpa_seed = makeSymbolicTensor({}, DataType::Int); + TensorView* mha_sdpa_offset = makeSymbolicTensor({}, DataType::Int); + TensorView* ln1_w = makeContigTensor(1); + TensorView* ln1_b = makeContigTensor(1); + TensorView* ln1_mean = makeConcreteTensor({B * S, 1}); + TensorView* ln1_rstd = makeConcreteTensor({B * S, 1}); + TensorView* ln0_w = makeContigTensor(1); + TensorView* ln0_b = makeContigTensor(1); + TensorView* ln0_mean = makeConcreteTensor({B * S, 1}); + TensorView* ln0_rstd = makeConcreteTensor({B * S, 1}); + TensorView* mha_linear0 = makeContigTensor(3, dtype); + TensorView* mha_linear1 = makeContigTensor(2); + TensorView* mlp_linear0 = makeContigTensor(3, dtype); + + fusion->addInput(x); + fusion->addInput(grad); + fusion->addInput(mha_w0); + fusion->addInput(mha_w1); + fusion->addInput(mlp_w0); + fusion->addInput(mlp_w1); + fusion->addInput(mlp_mask); + fusion->addInput(mha_mask); + fusion->addInput(mha_sdpa_out); + fusion->addInput(mha_sdpa_log_sumexp); + fusion->addInput(mha_sdpa_seed); + fusion->addInput(mha_sdpa_offset); + fusion->addInput(ln1_w); + fusion->addInput(ln1_b); + fusion->addInput(ln1_mean); + fusion->addInput(ln1_rstd); + fusion->addInput(ln0_w); + fusion->addInput(ln0_b); + fusion->addInput(ln0_mean); + fusion->addInput(ln0_rstd); + fusion->addInput(mha_linear0); + fusion->addInput(mha_linear1); + fusion->addInput(mlp_linear0); + + // Activation recomputation: mlp gelu, dropouts, and + // partially recompute layer norms using cached statistics. 
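// Dropout recomputation from the saved mask, in equation form (this is what
// the mul-by-mask / mul-by-scale pairs below implement):
//   dropout(x) = x * mask * (1 / (1 - p)),  mask ~ Bernoulli(1 - p)
// Re-using the forward-pass mask reproduces the forward activations exactly
// without having to store the post-dropout tensor.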
+ auto ln0_in = castOp(DataType::Float, x); + auto ln0 = layerNormWithCachedStats( + ln0_in, ln0_mean, ln0_rstd, norm_shape, ln0_w, ln0_b); + auto mha_in = castOp(dtype, ln0); + + Val* dropout_scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + // Use input mha_mask to implement dropout + auto mha_out = mul(mha_linear1, mha_mask); + mha_out = mul(mha_out, dropout_scale); + auto resid0 = add(ln0_in, mha_out); + auto ln1 = layerNormWithCachedStats( + resid0, ln1_mean, ln1_rstd, norm_shape, ln1_w, ln1_b); + auto mlp_in = castOp(dtype, ln1); + + // Backwards + auto grad_float = castOp(DataType::Float, grad); + auto mlp_grads = mlp_backwards( + grad_float, mlp_in, mlp_mask, mlp_w0, mlp_w1, mlp_linear0, mesh); + auto ln1_grads = layer_norm_backward( + castOp(DataType::Float, mlp_grads[6]), + resid0, + norm_shape, + ln1_mean, + ln1_rstd, + ln1_w, + ln1_b, + {true, true, true}); + auto resid1_grad = add(ln1_grads.grad_input, grad_float); + auto mha_grads = mha_backwards( + mha_in, + mha_w0, + mha_w1, + mha_mask, + mha_sdpa_out, + mha_sdpa_log_sumexp, + mha_sdpa_seed, + mha_sdpa_offset, + resid1_grad, + mha_linear0, + mesh); + auto ln0_grads = layer_norm_backward( + castOp(DataType::Float, mha_grads[8]), + ln0_in, + norm_shape, + ln0_mean, + ln0_rstd, + ln0_w, + ln0_b, + {true, true, true}); + auto dx = add(ln0_grads.grad_input, resid1_grad); + dx = castOp(dtype, dx); + + fusion->addOutput(mlp_grads[1]); // mlp linear1 weight grad + fusion->addOutput(mlp_grads[2]); // mlp linear1 bias grad + fusion->addOutput(mlp_grads[4]); // mlp linear0 weight grad + fusion->addOutput(mlp_grads[5]); // mlp linear0 bias grad + fusion->addOutput(ln1_grads.grad_weight); + fusion->addOutput(ln1_grads.grad_bias); + fusion->addOutput(mha_grads[1]); // mha linear1 weight grad + fusion->addOutput(mha_grads[2]); // mha linear1 bias grad + fusion->addOutput(mha_grads[6]); // mha linear0 weight grad + fusion->addOutput(mha_grads[7]); // mha linear0 bias grad + fusion->addOutput(ln0_grads.grad_weight); + fusion->addOutput(ln0_grads.grad_bias); + fusion->addOutput(dx); // transformer grad input + + // Sharding annotations for input and output TVs not sharded + // by mlp_backward or mha_backward + for (auto* tv : + {ln0_w, + ln0_b, + ln0_mean, + ln0_rstd, + ln1_w, + ln1_b, + ln1_mean, + ln1_rstd, + ln1_grads.grad_weight, + ln1_grads.grad_bias, + ln0_grads.grad_weight, + ln0_grads.grad_bias, + ln0_grads.grad_input}) { + tv->setDeviceMesh(mesh); + } + + // Sharded inputs to outputs + shardBetween( + {mha_w0, mha_w1, mha_sdpa_out}, + {mha_grads[1], mha_grads[6], mha_grads[7]}, + mha_w0); + shardBetween( + {mlp_w0, mlp_w1}, {mlp_grads[1], mlp_grads[4], mlp_grads[5]}, mlp_w0); + + // Unsharded inputs to outputs + shardBetween( + {x, + grad, + mha_mask, + mlp_mask, + mha_linear1, + ln0_mean, + ln0_w, + ln0_b, + ln1_mean, + ln1_w, + ln1_b}, + {mlp_grads[2], + ln1_grads.grad_weight, + ln1_grads.grad_bias, + mha_grads[2], + ln0_grads.grad_weight, + ln0_grads.grad_bias, + dx}, + x); + + return std::make_unique(std::move(fusion)); +} +} // namespace nvfuser diff --git a/tests/cpp/multidevice_transformer.h b/tests/cpp/multidevice_transformer.h new file mode 100644 index 00000000000..33a3f759926 --- /dev/null +++ b/tests/cpp/multidevice_transformer.h @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include + +namespace nvfuser { + +struct MlpResult { + TensorView* linear0; + TensorView* gelu; + TensorView* matmul1; + TensorView* linear1; + TensorView* output; +}; + +struct MhaResult { + TensorView* linear0; + TensorView* sdpa; + TensorView* matmul1; + TensorView* linear1; + TensorView* output; +}; + +class DistributedTransformer { + public: + DistributedTransformer( + int64_t num_devices, + int64_t batch_size, + int64_t embedding_size, + int64_t number_heads, + int64_t sequence_length, + double dropout_prob = 0.1, + double sdpa_dropout_prob = 0.1) + : D(num_devices), + B(batch_size), + E(embedding_size), + H(number_heads), + S(sequence_length), + kDropoutProb(dropout_prob), + kSdpaProb(sdpa_dropout_prob) {} + + std::unique_ptr forward( + DataType dtype, + bool sequence_parallel = false); + std::unique_ptr backward(DataType dtype); + + MlpResult mlp( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel = false); + + MhaResult mha( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel = false); + + std::vector mlp_backwards( + TensorView* grad, + TensorView* x, + TensorView* mask, + TensorView* w0, + TensorView* w1, + TensorView* linear0, + const DeviceMesh& mesh); + + std::vector mha_backwards( + TensorView* x, + TensorView* w0, + TensorView* w1, + TensorView* mask, + TensorView* sdpa_output, + TensorView* sdpa_log_sumexp, + TensorView* sdpa_seed, + TensorView* sdpa_offset, + TensorView* grad, + TensorView* linear0, + const DeviceMesh& mesh); + + const int64_t D, B, E, H, S; + const double kDropoutProb; + const double kSdpaProb; + static constexpr double kSdpaScale = 1e-3; +}; +} // namespace nvfuser diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index e9fb2dcf503..630c77ac458 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -959,34 +959,6 @@ TEST_F(AliasTest, SourceIsBothInputAndOutput) { EXPECT_EQ(in_tensor.data_ptr(), out_tensors[1].data_ptr()); } -TEST_F(AliasTest, SegmentBoundary) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* in = makeContigConcreteTensor({2, 3}); - TensorView* out = permute(in, {1, 0}); - // With the current segmentation algorithm, `slice` has to be the start of a - // fusion. So we expect `permute` to form a meta-op-only segment and the rest - // a pointwise segment. 
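// Usage sketch for the DistributedTransformer helper declared above (the
// constructor arguments, dtype and input preparation are illustrative; the
// multidevice tests drive it through their own fixtures):
DistributedTransformer model(
    /*num_devices=*/4,
    /*batch_size=*/1,
    /*embedding_size=*/128,
    /*number_heads=*/4,
    /*sequence_length=*/32);
auto executor_cache = model.forward(DataType::BFloat16, /*sequence_parallel=*/true);
auto outputs = executor_cache->runFusionWithInputs(sharded_inputs); // inputs assumed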
- out = slice(out, {0, 0}, {2, 2}); - out = add(out, out); - fusion->addInput(in); - fusion->addOutput(out); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn({2, 3}).cuda(); - at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; - testValidate( - executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); - - FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre( - HeuristicIs(SchedulerType::NoOp), - HeuristicIs(SchedulerType::PointWise))); -} - TEST_F(AliasTest, ReuseBuffer) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_bfs.cpp b/tests/cpp/test_bfs.cpp index 1172fc86afd..500766207d7 100644 --- a/tests/cpp/test_bfs.cpp +++ b/tests/cpp/test_bfs.cpp @@ -474,7 +474,7 @@ TEST_F(BFSTest, TraversalDirection) { EXPECT_TRUE(backward_path.empty()) << "Actual: " << backward_path; } -// A simple test for IRBFSWithPermissiveDependence +// A simple test for BFSWithPermissiveDependence TEST_F(BFSTest, IRBFSPermissiveTraversal) { Fusion fusion; FusionGuard fg(&fusion); @@ -508,7 +508,7 @@ TEST_F(BFSTest, IRBFSPermissiveTraversal) { // to: [i4] // -> forward merge, forward split { - auto path = getExprsBetween( + auto path = getExprsBetween( {i0, i2}, {i4}, /*require_all_to_visited=*/false) .first; EXPECT_EQ(path.size(), 2); @@ -524,7 +524,7 @@ TEST_F(BFSTest, IRBFSPermissiveTraversal) { // to: [i1] // -> bwd split, bwd merge { - auto path = getExprsBetween( + auto path = getExprsBetween( {i4, i5}, {i1}, /*require_all_to_visited=*/false) .first; EXPECT_EQ(path.size(), 2); diff --git a/tests/cpp/test_gpu1.cpp b/tests/cpp/test_gpu1.cpp index 05faa1d5b60..093265b2dcf 100644 --- a/tests/cpp/test_gpu1.cpp +++ b/tests/cpp/test_gpu1.cpp @@ -2711,13 +2711,17 @@ TEST_F(NVFuserTest, FusionFp8CastOps_CUDA) { std::vector inputs = {input1}; KernelExecutor ke; - +#if (CUDA_VERSION >= 12010) + if (!deviceMajorMinorCheck(8, 9)) { +#elif (CUDA_VERSION >= 11080) if (!deviceMajorMinorCheck(9)) { +#else + if (true) { +#endif ASSERT_THAT( [&]() { ke.compile(&fusion, inputs); }, testing::ThrowsMessage(testing::HasSubstr( - "Reason: Fusion contains Float8_xxx values which was introduced in Hopper (9.0)"))); - GTEST_SKIP() << "skipping tests on pre-HOPPER GPUs"; + "Reason: Fusion contains Float8_xxx values"))); } else { ke.compile(&fusion, inputs); auto outputs = ke.run(inputs); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 8fa235357b1..76d45f6de4c 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8051,23 +8051,27 @@ TEST_F(NVFuserTest, AvoidCachingSliceInput) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(inputs); - // check segment and sliced tvs are not cached + // check segmentation and sliced tvs are not cached if not scheduled by + // the resize scheduler auto kernel_runtime = executor_cache.getMostRecentKernelRuntime(); - NVF_CHECK(kernel_runtime->isSegmented(), "segmentation didn't happen"); const auto num_segments = kernel_runtime->fusionSegments()->groups().size(); - NVF_CHECK(num_segments == 3, "Expect 3 segments, got: ", num_segments); - for (const auto& exec : kernel_runtime->executors()) { + EXPECT_EQ(num_segments, 3) << "Expect 3 segments, got: " << num_segments; + for (const auto i : c10::irange(kernel_runtime->executors().size())) { + const auto& exec = 
kernel_runtime->executors().at(i); if (!exec->isA()) { continue; } + if (kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(i) + ->scheduler_type == SchedulerType::Resize) { + continue; + } const auto* ke = exec->as(); for (auto expr : ke->fusion()->exprs()) { if (expr->isA()) { auto slice = expr->as(); - NVF_CHECK( - slice->in()->getMemoryType() == MemoryType::Global, - "slice input must be in global memory, get: ", - slice->in()->getMemoryType()); + EXPECT_EQ(slice->in()->getMemoryType(), MemoryType::Global); } } } @@ -9245,8 +9249,6 @@ TEST_F(NVFuserTest, AllIdsMultipleDependencies) { tv1->split(0, 4); tv1->split(0, 8); - fusion.print(); - auto all_ids = tv1->domain()->allIDs(); auto split2 = tv1->axis(0)->definition()->as(); diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index cb44550b883..e97550309e1 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -11,9 +11,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) { c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0)); } +TEST_F(StreamTest, HostIrGetCurrentStream) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + auto get_stream = IrBuilder::create(); + auto current_stream = get_stream->stream(); + auto other_stream = IrBuilder::create(); + hic->pushBackTopLevelExprs(get_stream); + hic->pushBackTopLevelExprs(IrBuilder::create(other_stream)); + hic->pushBackTopLevelExprs( + IrBuilder::create(current_stream)); + + auto cuda_stream = c10::cuda::getStreamFromPool(); + setCurrentCUDAStream(cuda_stream); + + HostIrEvaluator hie(std::move(hic)); + hie.runWithInput({}); + + EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0)); +} + TEST_F(StreamTest, ByIndex) { constexpr int64_t kStreamIndex1 = 2; constexpr int64_t kStreamIndex2 = 3; diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index d9436979ba5..3691babd5b0 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -347,6 +348,13 @@ class PredicateIndexValidator : public kir::IrVisitor { auto out_ti = expr->output(0)->as(); + // This is just an initialization expr, likely by zero. Only the + // actual expr will be validted. 
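// For reference, the skipped pattern looks roughly like this in a generated
// kernel (sketch):
//   T5[i] = 0;              // initialization expr: scalar input, skipped here
//   T5[i] = T3[i] + T4[i];  // actual expr: this is the one that is validated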
+ if (out_ti->view()->definition()->input(0)->isA() && + expr->input(0)->isScalar()) { + return; + } + NVF_ERROR(!scope_exprs_.empty()); auto inline_ite = dynamic_cast(scope_exprs_.back()); NVF_ERROR( @@ -5390,6 +5398,105 @@ TEST_F(IndexingTest, ResizeRotation) { testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } +TEST_F(PredicateIndexingTest, VectorizedResizeRotation) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int64_t i0 = 32; + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto zero = fusion.zeroVal(); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeContigConcreteTensor({i0}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + // left half + auto tv2 = slice(tv1, {{zero, IrBuilder::create(i0 / 2)}}); + + auto tv3 = set(tv0); + // right half + auto tv4 = slice( + tv3, {{IrBuilder::create(i0 / 2), IrBuilder::create(i0)}}); + + // Rotation + auto tv5 = cat({tv4, tv2}, 0); + + auto tv6 = add(tv0, tv5); + + fusion.addOutput(tv6); + + for (Expr* expr : fusion.exprs()) { + if (expr->isOneOf()) { + scheduler_tools::propagateResizeToInputs(expr); + } + } + + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + + tv->split(0, 4); + } + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + inlineMost(); + + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getInlinePredicate(TensorView* tv) const override { + if (tv->name() != 1) { + return nullptr; + } + + if (for_loops_.back()->iter_domain()->getParallelType() != + ParallelType::Vectorize) { + return nullptr; + } + + std::vector loop_indices = getLoopIndices(tv, indexer_, for_loops_); + + Val* zero = tv->fusion()->zeroVal(); + + auto second_resize = dynamic_cast( + tv->axis(0)->definition()->input(0)->definition()); + EXPECT_NE(second_resize, nullptr); + + auto start_idx = addExpr( + IrBuilder::addExpr( + mulExpr(loop_indices.at(0), tv->axis(1)->extent()), zero), + IrBuilder::negExpr(second_resize->leftExpand())); + auto stop_idx = addExpr( + IrBuilder::addExpr( + mulExpr(loop_indices.at(0), tv->axis(1)->extent()), createInt(3)), + IrBuilder::negExpr(second_resize->leftExpand())); + + return andExpr( + geExpr(start_idx, tv->fusion()->zeroVal()), + ltExpr(stop_idx, tv->getLogicalDomain().at(0)->extent())); + } + }; + + PredicateIndexValidator::validate(&fusion, false); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({i0}, options); + std::vector inputs{t0}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + // Repro of issue #3505. The indexing WAR for resize triggered an // assertion due to loop promotion. 
TEST_F(IndexingTest, Issue3505Repro1) { diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp index 107e0081eee..66901cc0790 100644 --- a/tests/cpp/test_loop_domain_scheduling.cpp +++ b/tests/cpp/test_loop_domain_scheduling.cpp @@ -448,4 +448,43 @@ TEST_F(LoopDomainSchedulingTest, ScheduleLoopDomainsBy2) { checkGetAllStmts(&fusion); } +// Make sure existing exprs should not be reused if +// update_loop_domain_only is true +TEST_F(LoopDomainSchedulingTest, UpdateLoopDomainOnlyWithExistingExpr) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = reshape(tv1, {IrBuilder::create(-1L)}); + fusion.addOutput(tv2); + + auto reshape_merge = + dynamic_cast(tv2->getLogicalDomain().at(0)->definition()); + ASSERT_NE(reshape_merge, nullptr); + + // Cancel the tv2 reshape + scheduler_tools::scheduleLoopDomainsLike({tv2}, tv1->getLoopDomain()); + + // Schedule tv1 + tv1->flatten(); + + // Propagate the tv1 schedule to tv2 + scheduler_tools::scheduleLoopDomainsLike( + {tv2}, + tv1->getLoopDomain(), + /*update_loop_domain_only=*/true); + + // The merge of tv1, which is propagated to tv2, is exact mapped + // with the merge for the tv2 reshape. It should not be reused as + // the update_loop_domain_only flag is true. + auto propagated_merge = + dynamic_cast(tv2->getLoopDomain().at(0)->definition()); + ASSERT_NE(propagated_merge, nullptr); + + EXPECT_NE(reshape_merge, propagated_merge); +} + } // namespace nvfuser diff --git a/tests/cpp/test_loop_rotation.cpp b/tests/cpp/test_loop_rotation.cpp index db5f3e20848..4b5122a66fd 100644 --- a/tests/cpp/test_loop_rotation.cpp +++ b/tests/cpp/test_loop_rotation.cpp @@ -307,7 +307,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor nvfuser_index_t i0; i0 = 4LL * T0.alloc_stride[0LL]; float T1[15LL]; - #pragma unroll + #pragma unroll 4 for(nvfuser_index_t i1 = 0LL; i1 < 4LL; ++i1) { nvfuser_index_t i2; i2 = 3LL * i1; @@ -335,7 +335,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor = T1[i6]; } NVFUSER_UPDATE_MAGIC_ZERO; - #pragma unroll 5 + #pragma unroll 4 for(nvfuser_index_t i7 = 0LL; i7 < T0.logical_size[0LL]; ++i7) { nvfuser_index_t i8; i8 = 4LL + i7; @@ -433,7 +433,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor = T0[(T0.alloc_stride[1LL] * (i3 + nvfuser_zero))]; } NVFUSER_UPDATE_MAGIC_ZERO; - #pragma unroll + #pragma unroll 4 for(nvfuser_index_t i4 = 0LL; i4 < 4LL; ++i4) { nvfuser_index_t i5; i5 = 3LL + (3LL * i4); @@ -474,7 +474,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor = T1[i8]; } NVFUSER_UPDATE_MAGIC_ZERO; - #pragma unroll 5 + #pragma unroll 4 for(nvfuser_index_t i9 = 0LL; i9 < T0.logical_size[0LL]; ++i9) { nvfuser_index_t i10; i10 = 3LL * i9; @@ -572,7 +572,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor i0 = toSmem(T4); float* ptr1; ptr1 = T0.data + (4LL * T0.alloc_stride[0LL]); - #pragma unroll + #pragma unroll 4 for(nvfuser_index_t i2 = 0LL; i2 < 4LL; ++i2) { float* ptr3; ptr3 = T0.data + (T0.alloc_stride[0LL] * i2); @@ -602,7 +602,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor float T1[2LL]; T1[0LL] = T4[0LL]; - #pragma unroll 5 + #pragma unroll 4 for(nvfuser_index_t i7 = 0LL; i7 < T0.logical_size[0LL]; ++i7) { float* ptr8; ptr8 = ptr1 + (T0.alloc_stride[0LL] * i7); @@ -633,7 +633,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor } NVFUSER_UPDATE_MAGIC_ZERO; asm volatile("cp.async.commit_group;\n"); - #pragma unroll + #pragma unroll 1 
for(nvfuser_index_t i14 = 0LL; i14 < 2LL; ++i14) { T1[((1LL + i14) % 2LL)] = T4[(i11 + i14)]; diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index cbd51d97cfb..9e9395c5e18 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3657,6 +3657,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { const auto dtype = DataType::Half; constexpr bool use_smem_epilogue = false; + constexpr bool use_warp_specialization = true; constexpr int64_t stages = 4; constexpr int64_t prefetch = 3; @@ -3800,8 +3801,13 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { inlineMost(); - tv0c->circularBuffer(stages, prefetch); - tv1c->circularBuffer(stages, prefetch); + if (use_warp_specialization) { + tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); + tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); + } else { + tv0c->circularBuffer(stages, prefetch); + tv1c->circularBuffer(stages, prefetch); + } auto inputs = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype)); @@ -3993,4 +3999,248 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({K, M, 1}, options); + auto b_ref = at::randn({K, 1, N}, options); + auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + +TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = 
DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, K + auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // N, K + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({M, 1, K}, options); + auto b_ref = at::randn({1, N, K}, options); + auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + +TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, M + auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // N, K + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {1}); + + // Reorder the accumulator as [M, N, K] + // [M, K, N] -> [M, N, K] + tv2->reorder({{-1, -3}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({1, K, M}, options); + auto b_ref = at::randn({N, K, 1}, options); + auto out_ref = + at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + 
mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + +TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K + auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {1}); + + // Reorder the accumulator as [M, N, K] + // [M, K, N] -> [M, N, K] + tv2->reorder({{-2, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({M, K, 1}, options); + auto b_ref = at::randn({1, K, N}, options); + auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + auto cg_outputs = ke.run(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser diff --git a/tests/cpp/test_matmul_sass.cpp b/tests/cpp/test_matmul_sass.cpp index d332c504a18..81b7566b0e1 100644 --- a/tests/cpp/test_matmul_sass.cpp +++ b/tests/cpp/test_matmul_sass.cpp @@ -372,9 +372,9 @@ TEST_F(MatmulSASSTest, AmpereModifiersSharedMemoryEpilogue) { bool found_LDGDEPBAR = false; bool found_DEPBAR = false; // kAllSupportedMmaLayout; int BAR_COUNT = 0; - // we have at least 6 shared memory barriers in the kernel if - // use_shared_epilogue. If promote_prologue_smem_reuse, then 8 - const int EXPECTED_BAR_COUNT = promote_prologue_smem_reuse ? 8 : 6; + // we have at least 5 shared memory barriers in the kernel if + // use_shared_epilogue. If promote_prologue_smem_reuse, then 7 + const int EXPECTED_BAR_COUNT = promote_prologue_smem_reuse ? 
7 : 5; sass::Container sass; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 0ffde4364c1..8182464cf40 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -1060,7 +1060,7 @@ TEST_F(MatmulSchedulerTest, FusedMultiplySumOnly) { // for Ampere with strict ref check, hence single layout check TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { // TODO: Make these tests work with Hopper as well as Ampere - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9); + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); const int M = 128, N = 256, K = 512; const auto layout = MmaLayout::TT; @@ -2481,7 +2481,7 @@ class MatmulSchedulerPluginTest : public NVFuserTest { // Test that our fake plugin works to override the default heuristic TEST_F(MatmulSchedulerPluginTest, BasicMatmul) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9); + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); const int M = 128, N = 256, K = 512; const auto layout = MmaLayout::TT; auto fusion = std::make_unique(); @@ -2660,7 +2660,7 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) { testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); } -TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) { +TEST_F(MatmulSchedulerTest, PreBroadcastMmaBiasNeg) { // TODO: fix up params or switch to FusionExecutorCache when ready, then // enable Ampere NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); @@ -2671,12 +2671,20 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) { // A - tv0, B - tv1 auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::Half); auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::Half); + TensorView* tv2 = makeContigConcreteTensor({-1}, DataType::Half); fusion->addInput(tv0); fusion->addInput(tv1); + fusion->addInput(tv2); - auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + auto tv3 = fusedMultiplySum(tv0, tv1, {-1}); + // We add these computations to test + // scheduling (with epilogue) when the output of mma is not + // cast to half. + auto tv4 = maybeCastOp(DataType::Float, tv2); + auto tv5 = biasEpilogue(tv3, tv4); + auto tv6 = neg(tv5); - fusion->addOutput(tv2); + fusion->addOutput(tv6); NVF_CHECK( 1 == ir_utils::getOpsOfType(fusion.get()).size(), @@ -2689,10 +2697,14 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) { auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); auto a = at::randn({M, K}, options); auto b = at::randn({N, K}, options); + auto c = at::randn({M}, options); auto t0 = a.unsqueeze(1); auto t1 = b.unsqueeze(0); - auto tref = at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()); - std::vector inputs{t0, t1}; + auto tref = + atBiasEpilogue( + at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()), c.to(at::kFloat)) + .neg_(); + std::vector inputs{t0, t1, c}; MatmulParams mparams; mparams.supported_vec_size = {8, 8, 4}; @@ -2705,9 +2717,7 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) { mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = true; mparams.circular_buffer_options.smem_circular_buffer_stage = 2; - // TODO: Currently we use stmatrix whenever this is true. We cannot do that - // when the dtype is not 16 bits.
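// Rough dataflow of the epilogue this hunk adds (a restating sketch, not extra
// test code; names as in the test above):
//   tv3 = fusedMultiplySum(tv0, tv1, {-1});  // mma accumulator, kept in fp32
//   tv5 = biasEpilogue(tv3, tv4);            // tv4 is the bias cast to fp32
//   tv6 = neg(tv5);                          // fusion output, never cast to half
// This non-16-bit epilogue output is why use_smem_epilogue is switched on below.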
- mparams.use_smem_epilogue = false; + mparams.use_smem_epilogue = true; mparams.promote_prologue_smem_reuse = false; SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) @@ -2720,17 +2730,23 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) { NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); } -class MatmulFusionTest : public MatmulSchedulerTest, - public ::testing::WithParamInterface { +class MatmulFusionTest + : public MatmulSchedulerTest, + public ::testing::WithParamInterface> { protected: void SetUp() override { if (fusion_enabled) { EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); } + if (horizontal_fusion_enabled) { + EnableOptionsGuard::getCurOptions().set( + EnableOption::FuseMultipleMatmuls); + } } EnableOptionsGuard eog_; - bool fusion_enabled = GetParam(); + bool fusion_enabled = GetParam().first; + bool horizontal_fusion_enabled = GetParam().second; }; // Test that we can segment a Fusion containing two matmuls @@ -2788,21 +2804,28 @@ TEST_P(MatmulFusionTest, Llama2FFN) { const FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_TRUE(runtime->isSegmented()); + size_t expected_kernels = + fusion_enabled ? (horizontal_fusion_enabled ? 1 : 2) : 3; - if (fusion_enabled) { - EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2); - } else { - EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3); - } + EXPECT_EQ(runtime->fusionSegments()->groups().size(), expected_kernels); } INSTANTIATE_TEST_SUITE_P( , MatmulFusionTest, - ::testing::Bool(), - [](const testing::TestParamInfo& info) { - return info.param ? "fuse" : "dontfuse"; + ::testing::ValuesIn(std::vector>{ + {false, false}, + {true, false}, + {true, true}}), + [](const testing::TestParamInfo>& info) { + bool fuse = info.param.first; + bool horiz_fuse = info.param.second; + if (horiz_fuse) { + NVF_ERROR( + fuse, "Horizontal fusion enabled but overall fusion disabled"); + } + return fuse ? (horiz_fuse ? "fuse_horizontal" : "fuse_single") + : "dontfuse"; }); // This test can be used to check that an external plugin has been loaded. It @@ -3143,7 +3166,7 @@ INSTANTIATE_TEST_SUITE_P( #undef NVFUSER_TEST_CUDA_ARCH_GUARD TEST_F(MatmulSchedulerTest, OperandOrderIssue2434) { - NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9); + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0); int M = 32, N = 64, K = 128; std::unique_ptr fusion_ptr = std::make_unique(); @@ -3379,10 +3402,6 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) { // TODO: Remove this test once the architecture agnostic can be // run on hopper. TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) { - if (use_smem_epilogue) { - GTEST_SKIP() - << "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now"; - } const auto& [A, B] = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype)); const auto& C = matmulAtInput2D( diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index ce359b41d32..991fe732b72 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -2811,7 +2811,7 @@ TEST_P(LdMatrixTest, Regular) { // We get shapes M and N from MmaMacrao. The vector of ints are // the tile_m and tile_n factors (8x8, 16x8 and 16x16). 
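// With the DataType element added to the tuple below, one parameter instance
// looks roughly like the following (illustrative only, reusing names from this
// diff):
//   StMatrixTestParams p{macro, /*tile_sizes=*/{16, 16}, DataType::BFloat16};
//   auto dtype = std::get<2>(p);  // the new element, read the same way in the test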
-using StMatrixTestParams = std::tuple>; +using StMatrixTestParams = std::tuple, DataType>; class StMatrixTest : public NVFuserFixtureParamTest { protected: @@ -2829,6 +2829,7 @@ TEST_P(StMatrixTest, Regular) { auto macro = std::get<0>(GetParam()); auto tile_sizes = std::get<1>(GetParam()); + auto dtype = std::get<2>(GetParam()); auto sizeM = getM(macro); auto sizeN = getN(macro); int64_t tile_m = tile_sizes.at(0); @@ -2843,7 +2844,7 @@ TEST_P(StMatrixTest, Regular) { fusion.manage("st_matrix_m", sizeM); fusion.manage("st_matrix_n", sizeN); - auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, DataType::Half); + auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, dtype); fusion.addInput(tv0); // tv0 (global) -> tv1 (registers) auto tv1 = set(tv0); @@ -2859,19 +2860,24 @@ TEST_P(StMatrixTest, Regular) { tv0->split(0, 32); tv0->axis(1)->parallelize(ParallelType::TIDx); - auto s = - mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv1->getLoopDomain()); - tv1->setLoopDomain(s.as()); - tv1->setAllocationDomain(s.as(), true); + for (auto tv : {tv1, tv2}) { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv->getLoopDomain()); + tv->setLoopDomain(s.as()); + } + tv1->setAllocationDomain(tv1->getLoopDomain(), true); mma_utils::scheduleStMatrixForMmaOutput( tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n); + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + tv3->merge(0); tv3->split(0, 32); tv3->axis(1)->parallelize(ParallelType::TIDx); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); auto t0 = at::randn({sizeM, sizeN}, options); KernelExecutor ke; @@ -2886,13 +2892,14 @@ std::string testNameStMatrixTest( std::ostringstream os; auto macro = std::get<0>(info.param); auto tile_sizes = std::get<1>(info.param); + auto dtype = std::get<2>(info.param); auto sizeM = getM(macro); auto sizeN = getN(macro); auto tile_m = tile_sizes.at(0); auto tile_n = tile_sizes.at(1); os << "m_" << sizeM << "_n_" << sizeN << "_tile_m_" << tile_m << "_tile_n_" - << tile_n; + << tile_n << "_" << mma_utils::dtypeToChar(dtype); return os.str(); } @@ -2904,7 +2911,8 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( // tile_m, tile_n std::vector{16, 8}, - std::vector{16, 16})), + std::vector{16, 16}), + testing::Values(DataType::Half, DataType::BFloat16)), testNameStMatrixTest); TEST_P(LdMatrixTest, Transpose) { diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index 7e5ed33a8a6..7aafcafb8ab 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -405,7 +405,7 @@ using HopperMmaRSStMatrixTestParams = std::tuple< PrimDataType, MmaLayout, MmaInputSmemSwizzle, - std::vector>; + std::vector>; class HopperRSStmatrix : public HopperBase, @@ -415,7 +415,7 @@ class HopperRSStmatrix MmaMacro macro; PrimDataType dtype; MmaInputSmemSwizzle swizzle_b; - std::vector tile_sizes; + std::vector tile_sizes; void SetUp() override { HopperBase::SetUp(); @@ -434,8 +434,8 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { auto shapes = matmulAtInputShape3DHopperRS( getM(macro), getN(macro), getK(macro), layout); - auto tile_m = tile_sizes.at(0); - auto tile_n = tile_sizes.at(1); + int64_t tile_m = tile_sizes.at(0); + int64_t tile_n = tile_sizes.at(1); if (getM(macro) % tile_m || getN(macro) % tile_n) { GTEST_SKIP() << "skipping test as output is not divisible by tile size"; @@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { 
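// A concrete instance of the tile divisibility guard earlier in this test,
// assuming the usual M_N_K macro naming seen elsewhere in this diff
// (e.g. Hopper_64_256_16): M=64 and N=256 are divisible by both the 16x8 and
// 16x16 tiles, so the test runs, whereas an N=8 macro is skipped for
// tile_n=16 because 8 % 16 != 0.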
EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared); EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global); - { - auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( - tv3c->getLoopDomain()); - tv3c->setLoopDomain(s.as()); - tv3c->setAllocationDomain(s.as(), true); - } { auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( tv2->getLoopDomain()); @@ -531,8 +525,26 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { tv2->axis(-3)->parallelize(ParallelType::Mma); } + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv3c->getLoopDomain()); + tv3c->setLoopDomain(s.as()); + tv3c->setAllocationDomain(s.as(), true); + } + MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3); + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv3->getLoopDomain()); + + if (swizzle != MmaInputSmemSwizzle::None) { + mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle); + } + + tv3->setLoopDomain(s.as()); + } mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n); + tv3->axis(-1)->parallelize(ParallelType::Vectorize); mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle); @@ -545,11 +557,12 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { auto cg_outputs = ke.run({inputs.first, inputs.second}); auto tref = atMatmul( - inputs.first.squeeze().to(at::kFloat), - inputs.second.squeeze().to(at::kFloat), - layout); + inputs.first.squeeze().to(at::kFloat), + inputs.second.squeeze().to(at::kFloat), + layout) + .to(data_type_to_aten(dtype)); - EXPECT_TRUE(at::allclose(cg_outputs[0], tref.to(at::kHalf), 1e-1, 1e-1)); + EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-1, 1e-1)); } std::string testNameHopperRS( @@ -569,13 +582,13 @@ INSTANTIATE_TEST_SUITE_P( HopperRSStmatrix, testing::Combine( kAllHopperMacros, - testing::Values(DataType::Half), + testing::Values(DataType::Half, DataType::BFloat16), testing::Values(MmaLayout::TN, MmaLayout::TT), kAllSmemSwizzleModes, testing::Values( // M, N - std::vector{16, 8}, - std::vector{16, 16}))); + std::vector{16, 8}, + std::vector{16, 16}))); INSTANTIATE_TEST_SUITE_P( MmaTest, diff --git a/tests/cpp/test_move_pad.cpp b/tests/cpp/test_move_pad.cpp index 92ccaeae676..aa57dfec770 100644 --- a/tests/cpp/test_move_pad.cpp +++ b/tests/cpp/test_move_pad.cpp @@ -448,4 +448,36 @@ TEST_F(MovePadTest, BooleanCat) { __FILE__); } +TEST_F(MovePadTest, Issue3597Repro) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = slice( + tv0, + {{fusion.oneVal(), tv0->axis(0)->extent()}, + {fusion.zeroVal(), tv0->axis(1)->extent()}}); + auto tv3 = segment_set(tv2); + + auto tv4 = add(tv3, tv1); + auto tv5 = pad(tv4, {fusion.zeroVal(), fusion.oneVal()}); + auto tv6 = set(tv5); + fusion.addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5, 10}, options); + auto t1 = at::randn({4, 10}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index c93296aaa27..f5d8483de49 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ 
b/tests/cpp/test_multidevice_overlap.cpp @@ -760,14 +760,8 @@ class AllgatherOverlapTest : public MultiDeviceTest { // This test implements an allgather-based pipelining overlapping technique, // similar to the above reduce-scattered based pipelining overlapping technique TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningATenImplementation) { - std::vector streams; - std::generate_n( - std::back_inserter(streams), - params.number_of_streams, - [my_device_index = my_device_index_]() { - return c10::cuda::getStreamFromPool( - /*isHighPriority=*/false, my_device_index); - }); + std::vector streams = + createStreams(params.number_of_streams, my_device_index_); for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { @@ -1060,9 +1054,9 @@ TEST_F( for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { initializeIO(); + c10::intrusive_ptr comms_req = nullptr; for (auto i : c10::irange(number_of_rings_)) { - c10::intrusive_ptr comms_req = nullptr; for (auto j : c10::irange(number_of_steps_per_ring_)) { int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); @@ -1076,9 +1070,8 @@ TEST_F( auto ta_j_next_slice = ta_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); - if (comms_req != nullptr) { + if (j != 0) { comms_req->wait(); - comms_req = nullptr; } // send & matmul current index @@ -1098,4 +1091,181 @@ TEST_F( } } +TEST_F( + RingAllgatherOverlapTest, + RingAllgatherBasedPipeliningHostIRImplementation) { + auto hic = std::make_unique(); + FusionGuard::setCurFusion(hic.get()); + + TensorView* tva = makeSymbolicTensor(ta_.dim()); + TensorView* tvb_unsharded = makeSymbolicTensor(tb_unsharded_.dim()); + TensorView* tvc_unsharded = makeSymbolicTensor(tc_unsharded_.dim()); + hic->addInput(tva); + hic->addInput(tvb_unsharded); + hic->addInput(tvc_unsharded); + + auto* i = IrBuilder::create(DataType::Index); // for-loop running index + auto* start_i = hic->zeroVal(); + auto* stop_i = tva->axis(1)->extent(); + auto* step_i = hic->oneVal(); + auto* for_loop_i = IrBuilder::create( + /*IterDomain=*/tva->axis(1), + /*index=*/i, + start_i, + stop_i, + step_i, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + auto* j = IrBuilder::create(DataType::Index); + auto* start_j = hic->zeroVal(); + auto* stop_j = tva->axis(0)->extent(); + auto* step_j = hic->oneVal(); + auto* for_loop_j = IrBuilder::create( + /*IterDomain=*/tva->axis(0), + /*index=*/j, + start_j, + stop_j, + step_j, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + auto* stream_index = + mod(add(i, j), IrBuilder::create(params.number_of_streams)); + auto* set_stream = IrBuilder::create( + IrBuilder::create(stream_index)); + + auto* my_device_index_val = IrBuilder::create(my_device_index_); + auto* number_of_steps_per_ring_val = + IrBuilder::create(number_of_steps_per_ring_); + + auto* send_rank = mod( + add(my_device_index_val, hic->oneVal()), number_of_steps_per_ring_val); + auto* recv_rank = + mod(add(number_of_steps_per_ring_val, + sub(my_device_index_val, hic->oneVal())), + number_of_steps_per_ring_val); + + auto* slice_index = + mod(add(sub(my_device_index_val, j), number_of_steps_per_ring_val), + number_of_steps_per_ring_val); + auto* next_slice_index = + 
mod(add(sub(sub(my_device_index_val, j), hic->oneVal()), + number_of_steps_per_ring_val), + number_of_steps_per_ring_val); + + TensorView* tmp1 = select(tva, 0, slice_index); + TensorView* tmp2 = select(tva, 0, next_slice_index); + TensorView* tmp3 = select(tvc_unsharded, 0, slice_index); + TensorView* tva_j_curr_slice = select(tmp1, 0, i); + TensorView* tva_j_next_slice = select(tmp2, 0, i); + TensorView* tvc_j = select(tmp3, 0, i); + + auto* mm = + IrBuilder::create(tvc_j, tva_j_curr_slice, tvb_unsharded); + + // Setting the DeviceMesh of the communication's I/O is artificial but + // required at this point + DeviceMesh full_mesh(all_devices_); + tva_j_curr_slice->setDeviceMesh(full_mesh); + tva_j_next_slice->setDeviceMesh(full_mesh); + + auto* start_coalescing = IrBuilder::create(); + auto* send = IrBuilder::create( + P2PCommunicationType::SEND, tva_j_curr_slice, send_rank); + auto* recv = IrBuilder::create( + P2PCommunicationType::RECV, tva_j_next_slice, recv_rank); + auto* end_coalescing = IrBuilder::create(); + auto* wait = IrBuilder::create(end_coalescing); + + auto* cond = ne(j, hic->zeroVal()); + auto* wait_predicate = IrBuilder::create(cond); + auto* if_not_first_ring_step_wait = + IrBuilder::create(wait_predicate); + if_not_first_ring_step_wait->thenBody().push_back(wait); + + auto* comm_cond = ne(j, sub(stop_j, hic->oneVal())); + auto* comm_predicate = IrBuilder::create(comm_cond); + auto* if_not_last_ring_step_post_comms = + IrBuilder::create(comm_predicate); + if_not_last_ring_step_post_comms->thenBody().push_back(start_coalescing); + if_not_last_ring_step_post_comms->thenBody().push_back(send); + if_not_last_ring_step_post_comms->thenBody().push_back(recv); + if_not_last_ring_step_post_comms->thenBody().push_back(end_coalescing); + + std::vector loop_j_body = { + set_stream, + tmp1->definition(), + tmp2->definition(), + tmp3->definition(), + tva_j_curr_slice->definition(), + tva_j_next_slice->definition(), + tvc_j->definition(), + if_not_first_ring_step_wait, + if_not_last_ring_step_post_comms, + mm}; + for (Expr* expr : loop_j_body) { + for_loop_j->body().push_back(expr); + } + for_loop_i->body().push_back(for_loop_j); + + hic->pushBackTopLevelExprs(for_loop_i); + + // Synchronize all streams + auto* i_stream = + IrBuilder::create(DataType::Index); // running index of the for-loop + auto* start_stream = hic->zeroVal(); + auto* stop_stream = + IrBuilder::create(params.number_of_streams, DataType::Index); + auto* step_stream = hic->oneVal(); + auto* for_loop_stream = IrBuilder::create( + /*IterDomain=*/makeContigConcreteTensor({params.number_of_streams}) + ->axis(0), + /*index=*/i_stream, + start_stream, + stop_stream, + step_stream, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + auto* sync_stream = IrBuilder::create( + IrBuilder::create(i_stream)); + for_loop_stream->body().push_back(sync_stream); + hic->pushBackTopLevelExprs(for_loop_stream); + + hic->addOutput(tmp1); + hic->addOutput(tmp2); + hic->addOutput(tmp3); + hic->addOutput(tva_j_curr_slice); + hic->addOutput(tva_j_next_slice); + hic->addOutput(tvc_j); + + hir::HostIrEvaluator hie(std::move(hic), communicator_); + + for ([[maybe_unused]] const auto& _ : + c10::irange(params.number_of_iterations)) { + // I don't know why but this seems necessary... 
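// (One plausible reason, not verified in this diff: initializeIO() below draws
// fresh random inputs from ATen's generator, so re-seeding each iteration keeps
// the HostIr run and the validation reference on the same random sequence.)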
+ at::manual_seed(getATenRandomSeed()); + + initializeIO(); + + std::unordered_map inputs = { + {tva, ta_}, + {tvb_unsharded, tb_unsharded_}, + {tvc_unsharded, tc_unsharded_}}; + + hie.runWithInput(std::move(inputs)); + + validate(); + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 3f1eac51462..5a626bfc967 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -64,7 +64,7 @@ class PipelineTest : public MultiDeviceTest { void PipelineTest::validate(bool validate_with_prescribed_values) { if (!validate_with_prescribed_values) { // execute the fusion on one device without pipeline scheduling - auto fusion_copy = std::make_unique(*runtime->completeFusion()); + auto fusion_copy = std::make_unique(*fusion); unshard(fusion_copy.get()); FusionExecutorCache unsharded_fec(std::move(fusion_copy)); ref_unsharded_outputs = unsharded_fec.runFusionWithInputs(unsharded_inputs); @@ -83,10 +83,9 @@ void PipelineTest::validate(bool validate_with_prescribed_values) { } ASSERT_EQ(ref_unsharded_outputs.size(), outputs.size()); - for (int i : c10::irange(runtime->completeFusion()->outputs().size())) { - ASSERT_TRUE(runtime->completeFusion()->outputs().at(i)->isA()); - auto output_tv = - runtime->completeFusion()->outputs().at(i)->as(); + for (int i : c10::irange(fusion->outputs().size())) { + ASSERT_TRUE(fusion->outputs().at(i)->isA()); + auto output_tv = fusion->outputs().at(i)->as(); if (!output_tv->getDeviceMesh().has(communicator_->deviceId())) { continue; } @@ -126,7 +125,9 @@ void PipelineTest::executeAndValidate(bool validate_with_prescribed_values) { } runtime = std::make_unique( - std::move(fusion), *communicator_, host_ir_executor_params); + std::make_unique(*fusion), + *communicator_, + host_ir_executor_params); auto error_msg = runtime->validate(); if (error_msg != "") { GTEST_SKIP() << error_msg; diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index aaa5d3a3218..5b93c119c66 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -479,9 +479,8 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) { } FusionExecutorCache executor_cache(std::move(fusion)); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor in_tensor = - at::randn({8}, options) + at::randn({8}, tensor_options) .as_strided( {parallelizes_broadcast ? 3 : num_devices * 3, 8}, {0, 1}); at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 0f39ae6f6e5..6ccb217137f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -13,25 +13,33 @@ #include #include #include +#include #include namespace nvfuser { -constexpr int64_t B = 2, E = 768, H = 16, S = 128; +namespace { +// Note: We test on smaller model and input sizes to avoid high error +// accumulation for validation. +static constexpr int64_t B = 2, E = 768, H = 16, S = 128; +// Note: Dropout probabilities are set to 0. Since the dropout mask is sharded +// it throws off the seed offset between the sharded nvFuser program and the +// unsharded reference. 
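// With kDropoutProb == 0.0 the dropout scale 1.0 / (1.0 - kDropoutProb) is
// exactly 1.0, so dropout degenerates to an identity and the comparison no
// longer depends on how the dropout mask is sharded.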
+static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3; // Note parameters scaled by kParamScale following weight initialization // recommendations: // https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range -// Note: Sdpa probability is set to 0. Since the dropout mask is sharded it -// throws off the seed offset between the sharded nvFuser program and the -// unsharded reference. -constexpr double kDropoutProb = 0.0, kParamScale = 0.02, kSdpaProb = 0.0, - kSdpaScale = 1e-3; +static constexpr double kParamScale = 0.02; +} // namespace class DistributedTransformerTest : public MultiDeviceTest, public testing::WithParamInterface { protected: - DistributedTransformerTest() : D(communicator_->size()) {} + DistributedTransformerTest() : D(communicator_->size()) { + model = std::make_unique( + D, B, E, H, S, kDropoutProb, kSdpaProb); + } void SetUp() override { MultiDeviceTest::SetUp(); @@ -41,6 +49,7 @@ class DistributedTransformerTest } const int64_t D; // number of devices + std::unique_ptr model; }; namespace { @@ -271,405 +280,6 @@ std::vector reference_mha_backwards( linear0}; return tensors; } - -struct MlpResult { - TensorView* linear0; - TensorView* gelu; - TensorView* matmul1; - TensorView* linear1; - TensorView* output; -}; - -MlpResult mlp( - TensorView* x, - TensorView* w0, - TensorView* b0, - TensorView* w1, - TensorView* b1, - const DeviceMesh& mesh, - bool sequence_parallel = false) { - const DataType dtype = w0->dtype(); - - if (sequence_parallel) { - // Input arrives sharded and must be allgathered back - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - x = set(x); // allgather - x->axis(0)->parallelize(ParallelType::Serial); - // Reshape back to 2D. This is uncessary except to keep - // the shapes of linear0 the same for TP and TP+SP. 
- auto D = w0->axis(0)->extent()->value().as(); - x = reshape(x, {D, B * S / D, E}, {B * S, E}); - } - // Linear 0 - TensorView* linear0 = linear(x, w0, b0); - // GeLU - TensorView* gelu = tanh_gelu(castOp(DataType::Float, linear0)); - gelu = castOp(dtype, gelu); - // Linear 1 - TensorView* local_matmul1 = matmul(gelu, transpose(w1, 1, 2)); - if (sequence_parallel) { - // Remove after https://github.com/NVIDIA/Fuser/issues/2563 - // Reshape to explicitly pull the sharded axis into the logical domain - auto D = w0->axis(0)->extent()->value().as(); - local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); - } - TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter - std::vector bcast_mask(matmul1->nDims() - 1, true); - bcast_mask[matmul1->nDims() - 2] = false; - TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); - // Dropout - Val* prob = IrBuilder::create(1.0 - kDropoutProb); - Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - TensorView* dropout_result = dropout(linear1, prob, scale).output; - - // Tensor parallel shardings - for (auto* tv : {w0, b0, w1}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - for (auto* tv : {x, b1}) { - tv->setDeviceMesh(mesh); - } - - // Sequence parallel shardings - if (sequence_parallel) { - matmul1->setDeviceMesh(mesh); - matmul1->axis(1)->parallelize(ParallelType::DIDx); - } - - return {linear0, gelu, matmul1, linear1, dropout_result}; -} - -struct MhaResult { - TensorView* linear0; - TensorView* sdpa; - TensorView* matmul1; - TensorView* linear1; - TensorView* output; -}; - -MhaResult mha( - TensorView* x, - TensorView* w0, - TensorView* b0, - TensorView* w1, - TensorView* b1, - const DeviceMesh& mesh, - bool sequence_parallel = false) { - const auto D = w0->axis(0)->extent()->value().as(); - auto dtype = w0->dtype(); - - if (sequence_parallel) { - // Input arrives sharded and must be allgathered back - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - x = set(x); // allgather - x->axis(0)->parallelize(ParallelType::Serial); - // Reshape is uncessary, it is here to keep shapes with TP and TP+SP the - // same for validation. 
- x = reshape(x, {D, B * S / D, E}, {B * S, E}); - } - - TensorView* linear0 = linear(x, w0, b0); - // Forming the q,k,v vectors: - TensorView* qkv_cat = - reshape(linear0, {D, B * S, 3 * E / D}, {D, B, S, 3 * E / D}); - std::vector qkv = chunk(qkv_cat, 3, -1); - for (auto i : c10::irange(3)) { - qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); - qkv[i] = transpose(qkv[i], 2, 3); - } - // SDPA - SdpfaFwdResult sdpa = sdpfa_fwd( - qkv[0], - qkv[1], - qkv[2], - IrBuilder::create(kSdpaProb), - IrBuilder::create(true), - IrBuilder::create(kSdpaScale)); - TensorView* sdpa_output = sdpa.output; - // Linear 1 - TensorView* sdpa_transpose = transpose(sdpa_output, 2, 3); - TensorView* sdpa_reshape = - reshape(sdpa_transpose, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* local_matmul1 = matmul(sdpa_reshape, transpose(w1, 1, 2)); - if (sequence_parallel) { - // Remove after https://github.com/NVIDIA/Fuser/issues/2563 - // Reshape to explicitly pull the sharded axis into the logical domain - auto D = w0->axis(0)->extent()->value().as(); - local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); - } - TensorView* matmul1 = sum(local_matmul1, {0}); // allreduce - std::vector bcast_mask(matmul1->nDims() - 1, true); - bcast_mask[matmul1->nDims() - 2] = false; - TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); - // Dropout - Val* prob = IrBuilder::create(1.0 - kDropoutProb); - Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - TensorView* dropout_result = dropout(linear1, prob, scale).output; - - // Tensor parallel shardings - for (auto tv : {x, b1}) { - tv->setDeviceMesh(mesh); - } - for (auto tv : {w0, b0, w1}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - // Sequence parallel sharding. - if (sequence_parallel) { - matmul1->setDeviceMesh(mesh); - matmul1->axis(1)->parallelize(ParallelType::DIDx); - } - - return {linear0, sdpa_output, matmul1, linear1, dropout_result}; -} - -// TODO: These linear_backwards helper functions can be merged once -// we do not have logically split rfactor domain. 
-struct LinearBackwardsResult { - TensorView* grad_x; - TensorView* grad_w; - TensorView* grad_b; -}; - -// x format: [i0, i1] dtype -// weight format: [DID(D), i2/D, i1] dtype -// grad format: [DID(D) i0, i2/D] float or dtype -// outputs: grad_x [i0, i1] dtype -// grad_w [DID i2/D, i1] dtype -// grad_b [DID i2/2] dtype -LinearBackwardsResult linear_backwards( - TensorView* x, - TensorView* w, - TensorView* grad) { - DataType dtype = w->dtype(); - TensorView* grad_f = maybeCastOp(DataType::Float, grad); - TensorView* grad_q = maybeCastOp(dtype, grad); - TensorView* grad_x_partials = matmul(grad_q, w); - TensorView* grad_x = sum(grad_x_partials, {0}); // allreduce - TensorView* grad_q_t = transpose(grad_q, 1, 2); - TensorView* grad_w = matmul(grad_q_t, x); - TensorView* grad_b = sum(grad_f, {1}); - grad_b = castOp(dtype, grad_b); - - return {grad_x, grad_w, grad_b}; -} - -// x format: [DID, i0, i1/D] dtype -// weight format: [DID, i2, i1/D] dtype -// grad format: [i0, i2] float -// outputs: grad_x [DID i0, i1/D] dtype -// grad_w [DID, i2, i1/D] dtype -// grad_b [i2] dtype -LinearBackwardsResult sharded_linear_backwards( - TensorView* x, - TensorView* w, - TensorView* grad) { - DataType dtype = w->dtype(); - TensorView* grad_q = castOp(dtype, grad); - TensorView* grad_x = matmul(grad_q, w); - TensorView* grad_t = transpose(grad_q, 0, 1); - TensorView* grad_w = matmul(grad_t, x); - TensorView* grad_b = sum(grad, {0}); - grad_b = castOp(dtype, grad_b); - - return {grad_x, grad_w, grad_b}; -} - -// Forward layer_norm with cached mean_bcast and invstd tensors to avoid -// recomputing Welford. For use in backwards pass. -TensorView* layer_norm_with_cached_statistics( - TensorView* x, - TensorView* mean_bcast, - TensorView* invstd, - const std::vector& norm_shape, - TensorView* weight, - TensorView* bias) { - const int64_t kNumberOfDims = - (int64_t)TensorDomain::noReductions(x->getLogicalDomain()).size(); - const int64_t kOuterNumDims = kNumberOfDims - norm_shape.size(); - std::vector outer_broadcast_mask(kNumberOfDims, false); - for (const auto idx : c10::irange(kOuterNumDims)) { - outer_broadcast_mask[idx] = true; - } - - auto x_sub_mean = sub(x, mean_bcast); - auto y = mul(x_sub_mean, invstd); - - auto weight_bcast = broadcast(weight, outer_broadcast_mask); - y = mul(y, weight_bcast); - auto bias_bcast = broadcast(bias, outer_broadcast_mask); - return add(y, bias_bcast); -} - -// Backwards MLP block. 
-std::vector mlp_backwards( - TensorView* grad, - TensorView* x, - TensorView* mask, - TensorView* w0, - TensorView* w1, - TensorView* linear0, - const DeviceMesh& mesh) { - DataType dtype = w0->dtype(); - - // Activation recomputation: Always recompute gelu - TensorView* gelu = castOp(dtype, tanh_gelu(castOp(DataType::Float, linear0))); - - // Backwards pass - constexpr double kScale = 1.0 / (1.0 - kDropoutProb); - Val* dropout_scale = IrBuilder::create(kScale); - TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); - auto linear1_grads = sharded_linear_backwards(gelu, w1, dropout_grad); - TensorView* matmul1_grad_x_ = castOp(DataType::Float, linear1_grads.grad_x); - TensorView* gelu_grad = tanh_gelu_backward(matmul1_grad_x_, linear0); - auto linear0_grads = linear_backwards(x, w0, gelu_grad); - - // Manaul sharding annotations - for (auto tv : - {x, - grad, - mask, - dropout_grad, - linear1_grads.grad_b, - linear0_grads.grad_x}) { - tv->setDeviceMesh(mesh); - } - - for (auto tv : - {w0, - w1, - linear0, - linear1_grads.grad_x, - linear1_grads.grad_w, - gelu_grad, - linear0_grads.grad_w, - linear0_grads.grad_b}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - - std::vector outputs = { - dropout_grad, - linear1_grads.grad_w, - linear1_grads.grad_b, - gelu_grad, - linear0_grads.grad_w, - linear0_grads.grad_b, - linear0_grads.grad_x}; - return outputs; -} - -std::vector mha_backwards( - TensorView* x, - TensorView* w0, - TensorView* w1, - TensorView* mask, - TensorView* sdpa_output, - TensorView* sdpa_log_sumexp, - TensorView* sdpa_seed, - TensorView* sdpa_offset, - TensorView* grad, - TensorView* linear0, - const DeviceMesh& mesh) { - DataType dtype = w0->dtype(); - const auto D = w0->axis(0)->extent()->value().as(); - // Reform qkv from linear0 output - TensorView* qkv_cat = reshape( - castOp(DataType::Float, linear0), - {D, B * S, 3 * E / D}, - {D, B, S, 3 * E / D}); - std::vector qkv = chunk(qkv_cat, 3, -1); - for (auto i : c10::irange(3)) { - qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); - qkv[i] = transpose(qkv[i], 2, 3); - qkv[i] = castOp(dtype, qkv[i]); - qkv[i]->setDeviceMesh(mesh); - qkv[i]->axis(0)->parallelize(ParallelType::DIDx); - } - - // dropout backwards - constexpr double kScale = 1.0 / (1.0 - kDropoutProb); - auto dropout_scale = IrBuilder::create(kScale); - TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); - - // linear1 backwards - TensorView* sdpa_output_reshape = - transpose(sdpa_output, 2, 3); // D, B, S, H/D, E/H - sdpa_output_reshape = - reshape(sdpa_output_reshape, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - auto linear1_grads = - sharded_linear_backwards(sdpa_output_reshape, w1, dropout_grad); - - // SDPA backwards - TensorView* linear1_x_grad = - reshape(linear1_grads.grad_x, {D, B * S, E / D}, {D, B, S, H / D, E / H}); - linear1_x_grad = transpose(linear1_x_grad, 2, 3); // D, B, H/D, S, E/H - // Explicitly shard inputs before SDPA backward node - for (auto tv : {linear1_x_grad, sdpa_output, sdpa_log_sumexp}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - auto sdpa_grad = sdpfa_bwd( - linear1_x_grad, - qkv[0], - qkv[1], - qkv[2], - sdpa_output, - sdpa_log_sumexp, - /*dropout_p=*/IrBuilder::create(kSdpaProb), - /*is_causal=*/IrBuilder::create(true), - sdpa_seed, - sdpa_offset, - /*scale=*/IrBuilder::create(kSdpaScale)); - - TensorView* q_grad = transpose(sdpa_grad.grad_query, 2, 3); - q_grad = reshape(q_grad, {D, B, S, H / D, E 
/ H}, {D, B * S, E / D}); - TensorView* v_grad = transpose(sdpa_grad.grad_value, 2, 3); - v_grad = reshape(v_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* k_grad = transpose(sdpa_grad.grad_key, 2, 3); - k_grad = reshape(k_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* kqv_grad = cat({k_grad, q_grad, v_grad}, -1); - auto linear0_grads = linear_backwards(x, w0, kqv_grad); - - for (auto tv : - {x, - mask, - grad, - dropout_grad, - linear1_grads.grad_b, - linear0_grads.grad_x}) { - tv->setDeviceMesh(mesh); - } - for (auto tv : - {w0, - w1, - sdpa_output, - sdpa_log_sumexp, - linear0, - linear1_grads.grad_x, - linear1_grads.grad_w, - linear0_grads.grad_w, - linear0_grads.grad_b, - sdpa_grad.grad_query, - sdpa_grad.grad_key, - sdpa_grad.grad_value}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - return { - dropout_grad, - linear1_grads.grad_w, - linear1_grads.grad_b, - sdpa_grad.grad_query, - sdpa_grad.grad_key, - sdpa_grad.grad_value, - linear0_grads.grad_w, - linear0_grads.grad_b, - linear0_grads.grad_x}; -} } // namespace TEST_P(DistributedTransformerTest, MLP_Layer) { @@ -695,7 +305,7 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tvsout = mlp(tvx, tvw0, tvb0, tvw1, tvb1, mesh); + auto tvsout = model->mlp(tvx, tvw0, tvb0, tvw1, tvb1, mesh); fusion->addOutput(tvsout.linear0); fusion->addOutput(tvsout.gelu); @@ -768,7 +378,7 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { // Note only the sequence (S) dimension that is sharded // but to avoid DID parallelizations of inner logical axes // B*S is sharded. - auto tvsout = mlp(x, w0, b0, w1, b1, mesh, true); + auto tvsout = model->mlp(x, w0, b0, w1, b1, mesh, true); fusion->addInput(x); fusion->addInput(w0); @@ -842,7 +452,7 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tv_outs = mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh); + auto tv_outs = model->mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh); fusion->addOutput(tv_outs.linear0); fusion->addOutput(tv_outs.sdpa); @@ -907,7 +517,7 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tv_outs = mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh, true); + auto tv_outs = model->mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh, true); fusion->addOutput(tv_outs.linear0); fusion->addOutput(tv_outs.sdpa); @@ -974,7 +584,7 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { fusion->addInput(linear0); std::vector tv_outs = - mlp_backwards(grad, x, mask, w0, w1, linear0, mesh); + model->mlp_backwards(grad, x, mask, w0, w1, linear0, mesh); for (TensorView* tv : tv_outs) { fusion->addOutput(tv); @@ -1056,7 +666,7 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { fusion->addInput(tvsdpa_offset); fusion->addInput(linear0); - auto tvouts = mha_backwards( + auto tvouts = model->mha_backwards( tvx, tvw0, tvw1, @@ -1134,74 +744,10 @@ TEST_P(DistributedTransformerTest, Forward_SP) { } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); const auto mesh = DeviceMesh::createForNumDevices(D); - - TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, dtype); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / 
D}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype); - TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype); - TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype); - TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype); - - fusion->addInput(x); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(mha_w0); - fusion->addInput(mha_b0); - fusion->addInput(mha_w1); - fusion->addInput(mha_b1); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_b0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_b1); - constexpr float kEps = 1e-5; - auto eps = IrBuilder::create(kEps); std::vector norm_shape{E}; - auto ln_input = castOp(DataType::Float, x); - auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); - auto mha_in = castOp(dtype, ln0.output); - auto mha_tvs = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, true); - auto resid0 = add(ln_input, mha_tvs.output); - auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); - auto mlp_in = castOp(dtype, ln1.output); - auto mlp_tvs = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, true); - auto resid1 = add(resid0, mlp_tvs.output); - resid1 = castOp(dtype, resid1); - - fusion->addOutput(ln0.output); - fusion->addOutput(mha_tvs.output); - fusion->addOutput(ln1.output); - fusion->addOutput(mlp_tvs.output); - fusion->addOutput(resid1); - - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - // Propagate SP shardings from x through layernorms, dropouts, residual adds. - // Even though mha_in is part of the boundary set, residuals allow the - // shardings to propagate up the graph so we must cut off the propagation at - // the outputs of reduce scatters (mha and mlp matmul1) - shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x); - // Propagate TP sharding for MLP and MHA from sharded weights. We do not need - // to shard from mha_b0 or mlp_b0 because they are only consumed by their - // respective linear0 expression which is sharded from *_w0. 
- shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0); - shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1); - shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0); - shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1256,9 +802,9 @@ TEST_P(DistributedTransformerTest, Forward_SP) { shardTensor(mlp_out_, 0, mesh).unsqueeze(0), shardTensor(at_out, 0, mesh).unsqueeze(0)}; - FusionExecutorCache fec(std::move(fusion)); + auto fec = model->forward(dtype, true); at::manual_seed(getATenRandomSeed()); - auto outputs = fec.runFusionWithInputs(inputs); + auto outputs = fec->runFusionWithInputs(inputs); validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04}); } @@ -1269,67 +815,10 @@ TEST_P(DistributedTransformerTest, Forward) { } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); const auto mesh = DeviceMesh::createForNumDevices(D); - - TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* mlp_w0 = makeContigTensor(3, dtype); - TensorView* mlp_b0 = makeContigTensor(2, dtype); - TensorView* mlp_w1 = makeContigTensor(3, dtype); - TensorView* mlp_b1 = makeContigTensor(1, dtype); - - fusion->addInput(x); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(mha_w0); - fusion->addInput(mha_b0); - fusion->addInput(mha_w1); - fusion->addInput(mha_b1); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_b0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_b1); - constexpr float kEps = 1e-5; - auto eps = IrBuilder::create(kEps); std::vector norm_shape{E}; - auto ln_input = castOp(DataType::Float, x); - auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); - auto mha_in = castOp(dtype, ln0.output); - auto mha_out = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh).output; - auto resid0 = add(ln_input, mha_out); - auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); - auto mlp_in = castOp(dtype, ln1.output); - auto mlp_out = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh).output; - auto resid1 = add(resid0, mlp_out); - resid1 = castOp(dtype, resid1); - - fusion->addOutput(ln0.output); - fusion->addOutput(mha_out); - fusion->addOutput(ln1.output); - fusion->addOutput(mlp_out); - fusion->addOutput(resid1); - - for (auto tv : {x, ln0.output, ln1.output, resid1}) { - tv->setDeviceMesh(mesh); - } - - shardBetween({mha_in->definition()}, {mha_out->definition()}, mha_w0); - shardBetween({mlp_in->definition()}, {mlp_out->definition()}, mlp_w0); - shardBetween({x}, {mha_in}, x); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1380,9 +869,9 @@ TEST_P(DistributedTransformerTest, Forward) { std::vector expected_outputs = { ln0_out_, mha_out_, ln1_out_, mlp_out_, at_out}; - FusionExecutorCache executor_cache(std::move(fusion)); + auto 
executor_cache = model->forward(dtype); at::manual_seed(getATenRandomSeed()); - auto outputs = executor_cache.runFusionWithInputs(inputs); + auto outputs = executor_cache->runFusionWithInputs(inputs); validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04}); } @@ -1399,172 +888,6 @@ TEST_P(DistributedTransformerTest, Backward) { constexpr float kEps = 1e-5; std::vector norm_shape{E}; - TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); - TensorView* grad = makeContigTensor(2, dtype); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mlp_w0 = makeContigTensor(3, dtype); - TensorView* mlp_w1 = makeContigTensor(3, dtype); - TensorView* mha_mask = makeContigTensor(2, DataType::Bool); - TensorView* mlp_mask = makeContigTensor(2, DataType::Bool); - TensorView* mha_sdpa_out = makeConcreteTensor({D, B, H / D, S, E / H}, dtype); - TensorView* mha_sdpa_log_sumexp = - makeContigConcreteTensor({D, B, H / D, S}, DataType::Float); - TensorView* mha_sdpa_seed = makeSymbolicTensor({}, DataType::Int); - TensorView* mha_sdpa_offset = makeSymbolicTensor({}, DataType::Int); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* ln1_mean = makeConcreteTensor({B * S, 1}); - TensorView* ln1_rstd = makeConcreteTensor({B * S, 1}); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* ln0_mean = makeConcreteTensor({B * S, 1}); - TensorView* ln0_rstd = makeConcreteTensor({B * S, 1}); - TensorView* mha_linear0 = makeContigTensor(3, dtype); - TensorView* mha_linear1 = makeContigTensor(2); - TensorView* mlp_linear0 = makeContigTensor(3, dtype); - - fusion->addInput(x); - fusion->addInput(grad); - fusion->addInput(mha_w0); - fusion->addInput(mha_w1); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_mask); - fusion->addInput(mha_mask); - fusion->addInput(mha_sdpa_out); - fusion->addInput(mha_sdpa_log_sumexp); - fusion->addInput(mha_sdpa_seed); - fusion->addInput(mha_sdpa_offset); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(ln1_mean); - fusion->addInput(ln1_rstd); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(ln0_mean); - fusion->addInput(ln0_rstd); - fusion->addInput(mha_linear0); - fusion->addInput(mha_linear1); - fusion->addInput(mlp_linear0); - - // Activation recomputation: mlp gelu, dropouts, and - // partially recompute layer norms using cached statistics. 
- auto ln0_in = castOp(DataType::Float, x); - auto ln0 = layer_norm_with_cached_statistics( - ln0_in, ln0_mean, ln0_rstd, norm_shape, ln0_w, ln0_b); - auto mha_in = castOp(dtype, ln0); - - Val* dropout_scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - // Use input mha_mask to implement dropout - auto mha_out = mul(mha_linear1, mha_mask); - mha_out = mul(mha_out, dropout_scale); - auto resid0 = add(ln0_in, mha_out); - auto ln1 = layer_norm_with_cached_statistics( - resid0, ln1_mean, ln1_rstd, norm_shape, ln1_w, ln1_b); - auto mlp_in = castOp(dtype, ln1); - - // Backwards - auto grad_float = castOp(DataType::Float, grad); - auto mlp_grads = mlp_backwards( - grad_float, mlp_in, mlp_mask, mlp_w0, mlp_w1, mlp_linear0, mesh); - auto ln1_grads = layer_norm_backward( - castOp(DataType::Float, mlp_grads[6]), - resid0, - norm_shape, - ln1_mean, - ln1_rstd, - ln1_w, - ln1_b, - {true, true, true}); - auto resid1_grad = add(ln1_grads.grad_input, grad_float); - auto mha_grads = mha_backwards( - mha_in, - mha_w0, - mha_w1, - mha_mask, - mha_sdpa_out, - mha_sdpa_log_sumexp, - mha_sdpa_seed, - mha_sdpa_offset, - resid1_grad, - mha_linear0, - mesh); - auto ln0_grads = layer_norm_backward( - castOp(DataType::Float, mha_grads[8]), - ln0_in, - norm_shape, - ln0_mean, - ln0_rstd, - ln0_w, - ln0_b, - {true, true, true}); - auto dx = add(ln0_grads.grad_input, resid1_grad); - dx = castOp(dtype, dx); - - fusion->addOutput(mlp_grads[1]); // mlp linear1 weight grad - fusion->addOutput(mlp_grads[2]); // mlp linear1 bias grad - fusion->addOutput(mlp_grads[4]); // mlp linear0 weight grad - fusion->addOutput(mlp_grads[5]); // mlp linear0 bias grad - fusion->addOutput(ln1_grads.grad_weight); - fusion->addOutput(ln1_grads.grad_bias); - fusion->addOutput(mha_grads[1]); // mha linear1 weight grad - fusion->addOutput(mha_grads[2]); // mha linear1 bias grad - fusion->addOutput(mha_grads[6]); // mha linear0 weight grad - fusion->addOutput(mha_grads[7]); // mha linear0 bias grad - fusion->addOutput(ln0_grads.grad_weight); - fusion->addOutput(ln0_grads.grad_bias); - fusion->addOutput(dx); // transformer grad input - - // Sharding annotations for input and output TVs not sharded - // by mlp_backward or mha_backward - for (auto* tv : - {ln0_w, - ln0_b, - ln0_mean, - ln0_rstd, - ln1_w, - ln1_b, - ln1_mean, - ln1_rstd, - ln1_grads.grad_weight, - ln1_grads.grad_bias, - ln0_grads.grad_weight, - ln0_grads.grad_bias, - ln0_grads.grad_input}) { - tv->setDeviceMesh(mesh); - } - - // Sharded inputs to outputs - shardBetween( - {mha_w0, mha_w1, mha_sdpa_out}, - {mha_grads[1], mha_grads[6], mha_grads[7]}, - mha_w0); - shardBetween( - {mlp_w0, mlp_w1}, {mlp_grads[1], mlp_grads[4], mlp_grads[5]}, mlp_w0); - - // Unsharded inputs to outputs - shardBetween( - {x, - grad, - mha_mask, - mlp_mask, - mha_linear1, - ln0_mean, - ln0_w, - ln0_b, - ln1_mean, - ln1_w, - ln1_b}, - {mlp_grads[2], - ln1_grads.grad_weight, - ln1_grads.grad_bias, - mha_grads[2], - ln0_grads.grad_weight, - ln0_grads.grad_bias, - dx}, - x); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1667,9 +990,9 @@ TEST_P(DistributedTransformerTest, Backward) { shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; - FusionExecutorCache executor_cache(std::move(fusion)); + auto executor_cache = model->backward(dtype); at::manual_seed(getATenRandomSeed()); - auto outputs = executor_cache.runFusionWithInputs(inputs); + auto outputs = executor_cache->runFusionWithInputs(inputs); 
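// The switch from '.' to '->' above reflects that model->backward(dtype), like
// model->forward(dtype) earlier in this diff, returns an owning handle,
// presumably something along the lines of std::unique_ptr<FusionExecutorCache>,
// instead of constructing the FusionExecutorCache in place; the run call itself
// is unchanged.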
validate( expected_outputs, outputs, diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index bb1c6bd7bfb..94e5c1f6d93 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,15 @@ bool hasVectorizationCache(TensorView* tv) { return false; } +class DomainMapUnitTest : public scheduler_tools::DomainMap { + public: + DomainMapUnitTest(Fusion* fusion) : scheduler_tools::DomainMap(fusion) {}; + bool testTargetCoverage(TensorView* target_tv, TensorView* reference_tv) + const { + return areAllTargetIdsCoveredBy(target_tv, reference_tv); + } +}; + } // namespace TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { @@ -306,7 +316,7 @@ TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) { at::Tensor input1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -340,7 +350,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) { at::Tensor input1 = at::randn({1024, 2, 512}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs, false); auto pparams = cg_results.heuristic_params->as(); @@ -374,7 +384,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) { at::Tensor input1 = at::randn({1024, 512, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -414,7 +424,7 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { at::Tensor input1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -451,7 +461,7 @@ TEST_F(PointwiseTest, VIssue1567ectorizationFactorAnalysisCase3) { at::Tensor input1 = at::randn({512, 1024, 2}, options); std::vector aten_inputs = {input0, input1}; - // NOTE: force pointwise scheduler here just for testing purpose + // NOTE force pointwise scheduler here just for testing purpose auto cg_results = scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); auto pparams = cg_results.heuristic_params->as(); @@ -773,4 +783,503 @@ TEST_F(PointwiseTest, VectorizePadLoweringPermuted) { EXPECT_TRUE(found_vectorize); testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +TEST_F(PointwiseTest, DomainMapTestEg0) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i0, i1} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i0, i1} + auto tv1 = relu(tv0); + fusion->addOutput(tv1); + // tv2 {i0, b2, i1} + auto tv2 = 
broadcast(tv1, {false, true, false});
+  // tv3 {i0, b3{1 ex 4}, i1}
+  auto tv3 = expand(
+      tv2,
+      {tv2->axis(0)->extent(),
+       IrBuilder::create(4),
+       tv2->axis(2)->extent()});
+  // NOTE that currently expand doesn't introduce an iter domain operation, so
+  // we don't see that i4 is produced by realizing the expanded extent of b3{1
+  // ex 4}
+  // tv4 {i0, i4*i1}
+  auto tv4 = reshape(tv3, {2, 4, 3}, {2, 12});
+  fusion->addOutput(tv4);
+
+  DomainMapUnitTest domain_map(fusion);
+  // tv4 is not covered by tv1, because the expanded ID i4 participates in a
+  // transformation
+  EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv1));
+
+  // tv3 is not covered by tv1, because the missing ID b3{1 ex 4} is concretized
+  // as i4, which is not mapped on tv1
+  EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv1));
+
+  // tv1 is covered by tv4
+  EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv4));
+
+  // tv1 is not a valid reference
+  EXPECT_FALSE(domain_map.isValidReference(tv1));
+
+  // tv4 is a valid reference
+  EXPECT_TRUE(domain_map.isValidReference(tv4));
+
+  // validate generated kernel
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({4, 7}, options);
+  std::vector aten_inputs = {t0};
+  // NOTE force pointwise scheduler here for unit test
+  auto cg_results =
+      scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs);
+  testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__);
+}
+
+TEST_F(PointwiseTest, DomainMapTestEg1) {
+  auto fusion_ptr = std::make_unique();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  // tv0 {i0, i1}
+  TensorView* tv0 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  // tv1 {i2, i0, i1}
+  TensorView* tv1 = makeContigTensor(3);
+  fusion->addInput(tv1);
+  // tv2 {i0*i1}
+  auto tv2 = reshape(tv0, {2, 4}, {8});
+  fusion->addOutput(tv2);
+
+  // tv3 {b3, i0, i1}
+  auto tv3 = broadcast(tv0, {true, false, false});
+  // tv4 {i2, i0, i1}
+  auto tv4 = add(tv1, tv3);
+  fusion->addOutput(tv4);
+
+  DomainMapUnitTest domain_map(fusion);
+  // tv4 is not covered by tv2, because it misses i2
+  EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2));
+
+  // tv2 is covered by tv4
+  EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4));
+
+  // tv2 is not a valid reference
+  EXPECT_FALSE(domain_map.isValidReference(tv2));
+
+  // tv4 is a valid reference
+  EXPECT_TRUE(domain_map.isValidReference(tv4));
+
+  // validate generated kernel
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({2, 4}, options);
+  at::Tensor t1 = at::randn({3, 2, 4}, options);
+  std::vector aten_inputs = {t0, t1};
+  // NOTE force pointwise scheduler here for unit test
+  auto cg_results =
+      scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs);
+  testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__);
+}
+
+TEST_F(PointwiseTest, DomainMapTestEg2) {
+  auto fusion_ptr = std::make_unique();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  // tv0 {i0, i1}
+  TensorView* tv0 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  // tv1 {i0, i1}
+  auto tv1 = relu(tv0);
+  fusion->addOutput(tv1);
+  // tv2 {i0, b2, i1}
+  auto tv2 = broadcast(tv1, {false, true, false});
+  // tv3 {i0, b3{1 ex 4}, i1}
+  auto tv3 = expand(
+      tv2,
+      {tv2->axis(0)->extent(),
+       IrBuilder::create(4),
+       tv2->axis(2)->extent()});
+  fusion->addOutput(tv3);
+
+  DomainMapUnitTest domain_map(fusion);
+  // tv3 is covered by tv1, because the missing ID b3{1 ex 4} is broadcast and
+  //
doesn't get resolved to a concrete broadcast ID. + EXPECT_TRUE(domain_map.testTargetCoverage(tv3, tv1)); + + // tv1 is covered by tv4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv1, tv3)); + + // tv1 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv1)); + + // tv3 is a valid reference + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 7}, options); + std::vector aten_inputs = {t0}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapFactory) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv1 {i1} + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + // tv1 {i0, i1} + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv1); + + // tv2 {b2, b3, i1} + auto tv2 = broadcast(tv0, {true, true, false}); + // NOTE tv1 will be broadcasted to {b2, i0, i1} before the add. + // tv3 {b2, i0, i1} + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto size_val = IrBuilder::create(4.0, DataType::Int); + auto one_val = IrBuilder::create(1, DataType::Int); + // factory method creates an iter domain out of thin air + // tv4 {i4{4}, b4, i1} + auto tv4 = ones({size_val, one_val, tv0->axis(0)->extent()}, DataType::Float); + // tv5 {i4{4}, i0, i1} + auto tv5 = mul(tv2, tv4); + fusion->addOutput(tv5); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is not covered by tv3, because it's missing i4{4} + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv3)); + // tv1 is not covered by tv4, since it's missing i0 + EXPECT_FALSE(domain_map.testTargetCoverage(tv1, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv3)); + // tv5 has the same IDs as tv4, and is not a valid reference. + EXPECT_FALSE(domain_map.isValidReference(tv5)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::empty_strided({25}, {1}, options); + at::Tensor input1 = at::empty_strided({7, 25}, {25, 1}, options); + auto cg_outputs = executor_cache.runFusionWithInputs({input0, input1}); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + // This fusion currently cannot be scheduled as a single kernel. It is + // expected to be segmented as: g{(pointwise) + // inputs: tv0, tv1 + // outputs: tv2, tv3 + // tv2 = broadcast(tv0) + // tv3 = add (tv2, broadcast(tv1)) + // } + // + // g{(pointwise) + // inputs: tv2 + // outputs: tv5 + // tv4 = full({4, 1, i0}) + // tv5 = mul(tv2, tv4) + // } + EXPECT_EQ(segmented_fusion->groups().size(), 2); + + for (SegmentedGroup* group : segmented_fusion->groups()) { + const std::vector& exprs = group->exprs(); + + size_t num_full = std::count_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + if (num_full != 0) { + // this is the segment contains the factory op. 
+ EXPECT_EQ(exprs.size(), 2); + EXPECT_EQ(num_full, 1); + auto binary_op_iter = + std::find_if(exprs.begin(), exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + EXPECT_EQ( + (*binary_op_iter)->as()->getBinaryOpType(), + BinaryOpType::Mul); + Fusion* group_fusion = group->getFusion(); + // validate that we have a valid reference in the segmented fusion + DomainMapUnitTest group_dm(group_fusion); + EXPECT_EQ(group_fusion->outputs().size(), 1); + EXPECT_TRUE(group_dm.isValidReference( + group_fusion->outputs()[0]->as())); + } else { + // validate segmentation has the correct ops + EXPECT_EQ(exprs.size(), 3); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 2); + EXPECT_EQ( + std::count_if( + exprs.begin(), + exprs.end(), + [](Expr* expr) { return expr->isA(); }), + 1); + Fusion* group_fusion = group->getFusion(); + auto output_add = std::find_if( + group_fusion->outputs().begin(), + group_fusion->outputs().end(), + [](Val* val) { return val->definition()->isA(); }); + EXPECT_TRUE(output_add != group_fusion->outputs().end()); + DomainMapUnitTest group_dm(group_fusion); + // validate that the segmented fusion choose the add output as the + // reference + EXPECT_TRUE(group_dm.isValidReference((*output_add)->as())); + } + } + + testValidate(fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, b1, i0} + TensorView* tv1 = TensorViewBuilder().shape({-1, 1, -1}).build(); + fusion->addInput(tv1); + // tv2 {i2, b1, i0} + auto tv2 = add(tv1, tv0); + fusion->addOutput(tv2); + // i3 = resize(b1 + 4 + 4) + // tv3 {i3, i0} + auto tv3 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + // tv4 {i3*i0} + auto tv4 = reshape(tv3, {9, 5}, {45}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv2, because i3 is produced by b1 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + // tv2 is not covered by tv4, it's missing i2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv2)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({7, 1, 5}, {5, 5, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapPad1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {b1, i0} + TensorView* tv0 = TensorViewBuilder().shape({1, -1}).build(); + fusion->addInput(tv0); + // tv1 {i2, i3, i4, b5} + TensorView* tv1 = TensorViewBuilder().shape({-1, -1, -1, 1}).build(); + fusion->addInput(tv1); + + // tv2 {b6, b7, b1, i0} + auto tv2 = broadcast(tv0, {true, true, false, false}); + // tv3 {i2, 
i3, i4, i0} + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + // i8 = resize(b1 + 4 + 4) + // tv4 {i8, i0} + auto tv4 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + + // tv4 is covered by tv3, because i8 is produced by b1, a broadcast dimension + // concretized as i4 + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv3)); + // tv3 is not covered by tv4, it's missing i2 and i3 + EXPECT_FALSE(domain_map.testTargetCoverage(tv3, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv4)); + EXPECT_TRUE(domain_map.isValidReference(tv3)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::empty_strided({1, 5}, {5, 1}, options); + at::Tensor t1 = at::empty_strided({2, 3, 4, 1}, {12, 4, 1, 1}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice0) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i1, i0} + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b2} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i1*i0} + auto tv4 = reshape(tv3, {2, 4}, {8}); + fusion->addOutput(tv4); + + DomainMapUnitTest domain_map(fusion); + // tv2 and tv4 has the same source IDs, since b3 = resize(i0 + 0 - 3) + EXPECT_TRUE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_TRUE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, DomainMapSlice1) { + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // tv0 {i2, i1, i0} + TensorView* tv0 = makeContigTensor(3); + fusion->addInput(tv0); + // tv1 {i1, i0} + // use concrete tensor to avoid need of concretization + TensorView* tv1 = makeContigConcreteTensor({2, 4}); + fusion->addInput(tv1); + + // b3 = resize(i0 + 0 - 3) + // tv2 {i1, b3} + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(0L), + IrBuilder::create(1L), + IrBuilder::create(1L)}}); + fusion->addOutput(tv2); + // tv3 {i2, i1, i0} + auto tv3 = add(tv0, tv1); + // tv4 {i2, i1*i0} + auto tv4 = reshape(tv3, {2, 2, 4}, {2, 8}); + fusion->addOutput(tv4); + + 
DomainMapUnitTest domain_map(fusion); + // i2 is missing in tv2 + EXPECT_FALSE(domain_map.testTargetCoverage(tv4, tv2)); + EXPECT_TRUE(domain_map.testTargetCoverage(tv2, tv4)); + + EXPECT_FALSE(domain_map.isValidReference(tv2)); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + // validate generated kernel + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 4}, options); + at::Tensor t1 = at::randn({2, 4}, options); + std::vector aten_inputs = {t0, t1}; + // NOTE force pointwise scheduler here for unit test + auto cg_results = + scheduleAndRun(fusion, SchedulerType::PointWise, aten_inputs); + testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, DomainMapBroadcastIssue3653) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto tv0 = makeConcreteTensor({2, 4, 8}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({2}); + fusion.addInput(tv1); + + auto tv2 = reshape(tv0, {2, 4, 8}, {2, 32}); + auto tv3 = broadcast(tv1, {false, true}); + auto tv4 = add(tv2, tv3); + + // tv4 covers source IDs {2, 4, 8}. + fusion.addOutput(tv4); + // meanwhile, tv3's broadcast ID map through permissive to `32`, which is not + // directly contained by tv4's source IDs. This test ensures that we project + // the mapped ID back to its source IDs and correctly schedule this fusion as + // a single kernel. + fusion.addOutput(tv3); + + DomainMapUnitTest domain_map(fusion_ptr.get()); + EXPECT_TRUE(domain_map.isValidReference(tv4)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 4, 8}, options); + auto t1 = at::randn({2}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(!runtime->isSegmented()); + + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 33fb1b635ba..4661d6e5599 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -770,4 +771,215 @@ TEST_F(PresegTest, DisjointSetsOfExtentsConcreteSymbolic) { testValidate( executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); } + +// Trivial repeat pattern +TEST_F(PresegTest, TranslateRepeatToExpand1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, -1); + fusion.addOutput(tv1); + + { + // Make sure pad and cat no longer exist + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // 
Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Consecutive repetitions with the same IDs +TEST_F(PresegTest, TranslateRepeatToExpand2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, -1); + auto tv2 = cat({tv1, tv1}, -1); + + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Consecutive repetitions with different IDs +TEST_F(PresegTest, TranslateRepeatToExpand3) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({4, 8}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, 1); + auto tv2 = cat({tv1, tv1}, 0); + + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({4, 8}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Repeat the same ID of the same tensor multiple times. While the +// repetitions are the same, there's nothing to allow the output IDs +// to be mapped, so the translated fusion will be segmented. This is a +// downside compared to the original fusion, where all IDs are +// connected, so it's relatively straightforward to fuse them together +// without segmentation. 
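// For reference, the translation rewrites each repeat-by-concatenation into
// a broadcast + expand + reshape of the original tensor, e.g. for
// cat({tv0, tv0}, -1) with tv0 of shape {32} (hand-written sketch for
// illustration; the IR produced by the pass may differ in details):
//
//   auto tvb = broadcast(tv0, {true, false});        // {b, i0}
//   auto tve = expand(
//       tvb,
//       {IrBuilder::create<Val>(2L), tvb->axis(1)->extent()}); // {2, i0}
//   auto tvr = reshape(tve, {2, 32}, {64});           // repeated result
//
// In the case below, the two cat ops are translated independently, so the
// expanded IDs of the two outputs have no mapping between them, which is why
// the translated fusion ends up segmented.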
+TEST_F(PresegTest, TranslateRepeatToExpand4) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({4, 8}); + fusion.addInput(tv0); + + // Consecutive repetitions with the same IDs + auto tv1 = cat({tv0, tv0}, 1); + auto tv2 = cat({tv0, tv0}, 1); + + fusion.addOutput(tv1); + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({4, 8}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be segmented to two pointwise kernels + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + const auto& heuristic_list = runtime->schedulerHeuristics()->heuristicsList(); + ASSERT_EQ(heuristic_list.size(), 2); + EXPECT_EQ(heuristic_list.at(0)->scheduler_type, SchedulerType::PointWise); + EXPECT_EQ(heuristic_list.at(1)->scheduler_type, SchedulerType::PointWise); +} + +// Repeating more than two times +TEST_F(PresegTest, TranslateRepeatToExpand5) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0, tv0, tv0}, -1); + fusion.addOutput(tv1); + + { + // Make sure pad and cat no longer exist + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + } // namespace nvfuser::preseg_passes diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index a2479aafe8b..1757a8cd3ab 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -10,10 +10,10 @@ #include #include +#include #include #include #include -#include #include #include #include @@ -40,8 +40,7 @@ class ReshardingTest : public NVFuserFixtureParamTest { // FusionExecutorCache, simplify validation by using // FusionExecutorCache::getMostRecentKernelRuntime()->fusionSegments()->groups(). 
for (auto expr : fusion_->exprs()) { - EXPECT_TRUE(!isResharding(expr) || isLowerableToCommunication(expr)) - << "on expr=" << expr; + EXPECT_TRUE(HostIrLower::canLower(expr)) << "on expr: " << expr; } SegmentCandidateFinderOptions options{ diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 0b7e816cc46..587f72143a4 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -55,7 +55,27 @@ void checkLoopDomainEquivalence( } // namespace -using ResizeTest = NVFuserTest; +class ResizeTest : public NVFuserTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + NVFuserTest::SetUp(); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; + +class ResizeSchedulerTest : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + NVFuserFixtureParamTest::SetUp(); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; using testing::Each; using testing::HasSubstr; @@ -64,6 +84,14 @@ using testing::Property; using testing::ThrowsMessage; using testing::UnorderedElementsAre; +INSTANTIATE_TEST_SUITE_P( + , + ResizeSchedulerTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "Scheduler" : "Manual"; + }); + // Simple pad test TEST_F(ResizeTest, Pad1) { Fusion fusion; @@ -2055,7 +2083,10 @@ TEST_F(ResizeTest, ResizeReshapeAndSlice) { } // Make sure resize works with the transpose scheduler -TEST_F(ResizeTest, ResizePermuteAndSlice) { +// This is consumed by the resize scheduler. We should extend the +// transpose scheduler to support resize without the segment-input +// requirement. +TEST_F(ResizeTest, DISABLED_ResizePermuteAndSlice) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2311,15 +2342,20 @@ TEST_F(ResizeTest, SliceVectorization) { constexpr int N = 1024 * 1024 * 64; - auto tv0 = makeContigConcreteTensor({N + 1}); + auto tv0 = makeContigConcreteTensor({N + 8}); fusion.addInput(tv0); auto tv1 = makeContigConcreteTensor({N}); fusion.addInput(tv1); + // Vectorization analysis is conservative. We considers the resize extent on + // both side. The slice here technically could have vectorization enabled, + // even when tv0 is sized as {N + 7}, which gives us resize extent `-3`. but + // the analysis doesn't support it at this time and requires resize extent to + // be vectorization friendly size. auto tv2 = slice( tv0, - {{IrBuilder::create(1L), - IrBuilder::create(N + 1L), + {{IrBuilder::create(4L), + IrBuilder::create(N + 4L), IrBuilder::create(1L)}}); auto tv3 = add(tv2, tv1); @@ -2327,7 +2363,7 @@ TEST_F(ResizeTest, SliceVectorization) { fusion.addOutput(tv3); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn(N + 1, options); + at::Tensor t0 = at::randn(N + 8, options); at::Tensor t1 = at::randn(N, options); std::vector inputs = {t0, t1}; @@ -2606,7 +2642,7 @@ TEST_F(ResizeTest, SliceAndReshape2) { } // Trivial case of slice vectorization. 
Just slicing a fusion input -TEST_F(ResizeTest, Slice1DVectorizeManual1) { +TEST_F(ResizeTest, Slice1DVectorize) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -2624,28 +2660,70 @@ TEST_F(ResizeTest, Slice1DVectorizeManual1) { sub(tv0->axis(0)->extent(), IrBuilder::create(slice_offset))}}); fusion.addOutput(tv1); - tv1->split(0, 4); - tv1->split(0, 128); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv1->axis(2)->parallelize(ParallelType::Vectorize); + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); + // check vectorization + ASSERT_EQ(pparams->vectorization_factor, 4) + << "Unexpected factor of vectorization"; + EXPECT_THAT( + tv1->getLoopDomain(), + Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) + << "Failed to vectorize: " << tv1; + + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +// An input is sliced twice. Both should be vectorizable. +TEST_F(ResizeTest, Slice1DVectorize2) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const int64_t slice_offset = 4; + const std::vector shape({1024L * 1024L}); + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + // Following two slices are vectorized individually. No cache is introduced + auto tv1 = slice( + tv0, + {{IrBuilder::create(slice_offset), + sub(tv0->axis(0)->extent(), IrBuilder::create(slice_offset))}}); + fusion.addOutput(tv1); + + auto tv2 = slice( + tv0, + {{IrBuilder::create(slice_offset * 2), + sub(tv0->axis(0)->extent(), + IrBuilder::create(slice_offset * 2))}}); + fusion.addOutput(tv2); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, aten_inputs); - auto cg_outputs = ke.run(aten_inputs); + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); + // check vectorization + ASSERT_EQ(pparams->vectorization_factor, 4) + << "Unexpected factor of vectorization"; + EXPECT_THAT( + tv1->getLoopDomain(), + Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) + << "Failed to vectorize: " << tv1; - auto ref = - t0.index({at::indexing::Slice(slice_offset, shape[0] - slice_offset)}); - ASSERT_TRUE(ref.equal(cg_outputs[0])); + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } // An input is sliced twice. Both should be vectorizable. -TEST_F(ResizeTest, Slice1DVectorizeManual2) { +TEST_F(ResizeTest, Slice1DVectorize2Manual) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -2701,7 +2779,46 @@ TEST_F(ResizeTest, Slice1DVectorizeManual2) { } // An input is sliced and also entirely read. Both should be vectorizable. 
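// A note on the slice offsets used in these slice vectorization tests: for a
// contiguous float tensor, a slice offset of 4 elements shifts the start of
// the sliced region by 4 * sizeof(float) = 16 bytes, so a factor-4 (float4,
// 16-byte) vectorized access remains aligned, which is what the
// vectorization_factor == 4 checks rely on. An offset of a single element
// (4 bytes) would not keep that alignment, and such cases are currently left
// to the manual tests (see the TODO on the [1:-3] slice further down). This
// is illustrative arithmetic only; the exact rules live in the vectorization
// analysis.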
-TEST_F(ResizeTest, Slice1DVectorizeManual3) { +TEST_F(ResizeTest, Slice1DVectorize3) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const int64_t slice_offset = 4; + const std::vector shape({1024L * 1024L}); + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(slice_offset), + sub(tv0->axis(0)->extent(), IrBuilder::create(slice_offset))}}); + fusion.addOutput(tv1); + + auto tv2 = set(tv0); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); + // check vectorization + ASSERT_EQ(pparams->vectorization_factor, 4) + << "Unexpected factor of vectorization"; + EXPECT_THAT( + tv1->getLoopDomain(), + Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) + << "Failed to vectorize: " << tv1; + + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + +// An input is sliced and also entirely read. Both should be vectorizable. +TEST_F(ResizeTest, Slice1DVectorize3Manual) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -2749,6 +2866,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual3) { ASSERT_TRUE(t0.equal(cg_outputs.at(1))); } +// TODO: this is a case not yet supported by vectorization analysis // Vectorizing a slice of [1:-3]. It's vectorizable as long as the // offset at 1 is aligned TEST_F(ResizeTest, Slice1DVectorizeManual4) { @@ -2788,7 +2906,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual4) { } // Contig merged vectorization with slice -TEST_F(ResizeTest, Slice2DVectorizeManual1) { +TEST_F(ResizeTest, Slice2DVectorize1) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -2810,36 +2928,32 @@ TEST_F(ResizeTest, Slice2DVectorizeManual1) { {IrBuilder::create(0), tv0->axis(1)->extent()}}); fusion.addOutput(tv1); - tv1->merge(0); - tv1->split(0, 4); - tv1->split(0, 128); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv1->axis(2)->parallelize(ParallelType::Vectorize); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, aten_inputs); - auto cg_outputs = ke.run(aten_inputs); + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); + // check vectorization + ASSERT_EQ(pparams->vectorization_factor, 4) + << "Unexpected factor of vectorization"; + EXPECT_THAT( + tv1->getLoopDomain(), + Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) + << "Failed to vectorize: " << tv1; - auto ref = t0.index( - {at::indexing::Slice(slice_offset, shape[0] - slice_offset), - at::indexing::Slice(0, at::indexing::None)}); - ASSERT_TRUE(ref.equal(cg_outputs.at(0))); + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } // Fully contiguous tensor, but a sliced domain makes the domain to -// the left non-contiguous -TEST_F(ResizeTest, Slice3DVectorizeManual1) { +// the left non-contiguous, hence we need to check for its stride +TEST_F(ResizeTest, Slice3DVectorize1) { 
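// Context for the expected vectorization factor of 1: the input is fully
// contiguous with shape {1024, 1025, 3}, and slicing the middle dimension to
// [4, 1024) leaves a contiguous run of 1020 * 3 floats per outer index, but
// each run starts at outer_index * 1025 * 3 = outer_index * 3075 elements
// (plus the slice offset). Since 3075 is odd, the outer stride is not
// divisible by any vector width > 1, so the scheduler has to fall back to
// word size 1, as the check at the end of this test asserts.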
auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - const std::vector shape({4, 1025, 3}); + const std::vector shape({1024, 1025, 3}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2847,48 +2961,32 @@ TEST_F(ResizeTest, Slice3DVectorizeManual1) { auto tv1 = slice( tv0, {{IrBuilder::create(0), tv0->axis(0)->extent()}, - {IrBuilder::create(4), IrBuilder::create(6)}, + {IrBuilder::create(4), IrBuilder::create(1024)}, {IrBuilder::create(0), tv0->axis(2)->extent()}}); fusion.addOutput(tv1); - // Vectorize tv1 by a factor of 2. The sliced domain and the - // innermost domain can be contiguous merged, thus producing a - // domain of extent 6, so vectorization by a factor of 2 appears to - // be valid, but due to the middle domain being sliced, the - // outermost domain is no longer contiguous, which means its stride - // must be divisible by 2, which is not the case here. - - // [4, 2, 3] - tv1->merge(1); - // [4, 6] - tv1->split(1, 2); - // [4, 3, 2] - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv1->axis(2)->parallelize(ParallelType::Vectorize); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, aten_inputs); + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); - EXPECT_THAT( - [&]() { ke.run(aten_inputs); }, - ThrowsMessage( - HasSubstr("with word size 2 not possible due to invalid stride"))); + ASSERT_EQ(pparams->vectorization_factor, 1) + << "Unexpected factor of vectorization"; + + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } -// Similar to Slice3DVectorizeManual2 but with a middle broadcast +// Similar to Slice3DVectorize2 but with a middle broadcast // domain -TEST_F(ResizeTest, Slice3DVectorizeManual2) { +TEST_F(ResizeTest, Slice3DVectorize2) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - const std::vector shape({4, 1, 1025, 3}); + const std::vector shape({1024, 1, 1025, 3}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2901,27 +2999,18 @@ TEST_F(ResizeTest, Slice3DVectorizeManual2) { {IrBuilder::create(0), tv0->axis(3)->extent()}}); fusion.addOutput(tv1); - // [4, 1, 1024, 3] - tv1->merge(2); - // [4, 1, 3072] - tv1->split(2, 4); - // [4, 1, 768, 4] - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(2)->parallelize(ParallelType::TIDx); - tv1->axis(3)->parallelize(ParallelType::Vectorize); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, aten_inputs); + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs); + auto pparams = cg_results.heuristic_params->as(); + // check vectorization + ASSERT_EQ(pparams->vectorization_factor, 1) + << "Unexpected factor of vectorization"; - EXPECT_THAT( - [&]() { ke.run(aten_inputs); }, - ThrowsMessage( - HasSubstr("with word size 4 not possible due to invalid stride"))); + testValidate(&fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__); } // Repro of issue 540 without transpose @@ -3046,11 +3135,9 @@ TEST_F(ResizeTest, ReshapeToPad) { auto outputs = executor_cache.runFusionWithInputs(aten_inputs); - // Assert 
that we segmented into two segments auto seg_fusion = executor_cache.getMostRecentKernelRuntime()->fusionSegments(); - EXPECT_TRUE(seg_fusion->isSegmented()); - EXPECT_EQ(seg_fusion->groups().size(), 2); + EXPECT_EQ(seg_fusion->groups().size(), 1); testValidate( executor_cache.fusion(), @@ -3405,14 +3492,12 @@ TEST_F(ResizeTest, PadVectorization) { ASSERT_EQ(pparams->vectorization_factor, 4) << "Unexpected factor of vectorization"; - // Make sure tv1 is not vectorized, i.e., no loop IterDomains are vectorized. + // Make sure tv1/tv2 are vectorized, i.e., at least one loop IterDomain is + // vectorized. EXPECT_THAT( tv1->getLoopDomain(), Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) << "Failed to vectorize: " << tv1; - - // Make sure tv2 should be vectorized, i.e., at least one loop IterDomain is - // vectorized. EXPECT_THAT( tv2->getLoopDomain(), Contains(Property(&IterDomain::getParallelType, ParallelType::Vectorize))) @@ -3990,9 +4075,10 @@ TEST_F(ResizeTest, SliceSliceConcatConcat) { } // Consumer-based scheduling of slice -TEST_F(ResizeTest, PropagateSliceToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4002,64 +4088,85 @@ TEST_F(ResizeTest, PropagateSliceToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + // Dont't use set here as it gets taken by the no-op scheduler + auto tv1 = sin(tv0); auto tv2 = slice( tv1, {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, {IrBuilder::create(1L), IrBuilder::create(99)}}); - auto tv3 = set(tv2); + auto tv3 = cos(tv2); fusion.addOutput(tv3); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - auto ref_tv = tv3; + const bool use_scheduler = GetParam(); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv3; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + 
if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Propagating slice to inputs with reshape before slice -TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputsWithReshape1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({16, 100}); @@ -4069,7 +4176,7 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = reshape(tv1, shape, {16, 5, 20}); @@ -4079,57 +4186,77 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { {fusion.zeroVal(), tv2->getLogicalDomain().at(1)->extent()}, {IrBuilder::create(1L), IrBuilder::create(10)}}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - scheduler_tools::propagateResizeToInputs(tv3->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv3->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv4; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Propagating slice to inputs with reshape after slice -TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputsWithReshape2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({16, 100}); @@ -4139,7 +4266,7 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, @@ -4148,53 +4275,73 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { auto tv3 = reshape(tv2, {shape[0], 49}, {shape[0] * 49}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + auto ref_tv = tv4; - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } -TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateMultipleSlicesToInputs1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4204,7 +4351,7 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, @@ -4216,69 +4363,96 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { {{fusion.zeroVal(), tv2->getLogicalDomain().at(0)->extent()}, {IrBuilder::create(1L), tv2->getLogicalDomain().at(1)->extent()}}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - // Propagate the first slice to tv1 - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - // Propagate the second slice to tv1 and tv2 - scheduler_tools::propagateResizeToInputs(tv3->definition()); + const bool use_scheduler = GetParam(); - // Each of tv1 and tv2 has two resize ops. - for (auto tv : {tv1, tv2}) { - auto resize1 = dynamic_cast(tv->axis(-1)->definition()); - EXPECT_NE(resize1, nullptr); - auto resize2 = dynamic_cast(resize1->in()->definition()); - EXPECT_NE(resize2, nullptr) << tv->toString(); - } + if (!use_scheduler) { + // Propagate the first slice to tv1 + scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto ref_tv = tv4; + // Propagate the second slice to tv1 and tv2 + scheduler_tools::propagateResizeToInputs(tv3->definition()); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Each of tv1 and tv2 has two resize ops. 
+ for (auto tv : {tv1, tv2}) { + auto resize1 = dynamic_cast(tv->axis(-1)->definition()); + EXPECT_NE(resize1, nullptr); + auto resize2 = dynamic_cast(resize1->in()->definition()); + EXPECT_NE(resize2, nullptr) << tv->toString(); + } - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv4; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + // Make sure all slices are detected as exclusive + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_TRUE(non_exclusive_resize_info.empty()); - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } -// RoPE-like rotation patten -TEST_F(ResizeTest, SliceRotateCat) { - Fusion fusion; - FusionGuard fg(&fusion); +// Two horizontal slices, both of which slice the same iter domain. 
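// Background for the next tests: a slice is treated as "exclusive" when its
// input tensor is not also consumed elsewhere through the same sliced iter
// domain; otherwise, propagating the resize back to the shared producer would
// also affect the other consumer. The tests query this with the exact graph,
// roughly as follows (the template argument of getOpsOfType is assumed here
// to be the resize-based op of interest, e.g. SliceOp):
//
//   IdModel id_model(&fusion, /*build_graphs=*/false);
//   const auto& exact_graph = id_model.buildExactGraph();
//   auto info = scheduler_tools::getNonExclusiveResizeInfo(
//       ir_utils::getOpsOfType<SliceOp>(&fusion), exact_graph);
//   // 'info' maps each non-exclusive resize output to the shared producers
//   // and the ID groups that cause the conflict.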
+TEST_F(ResizeSchedulerTest, PropagateMultipleSlicesToInputs2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4288,92 +4462,534 @@ TEST_F(ResizeTest, SliceRotateCat) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, - {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + {IrBuilder::create(1L), tv1->getLogicalDomain().at(1)->extent()}}); + + auto tv3 = sin(tv2); + + auto tv4 = sin(tv1); + + auto tv5 = slice( + tv4, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(2L), tv1->getLogicalDomain().at(1)->extent()}}); + + auto tv6 = sin(tv5); + + fusion.addOutput(tv3); + fusion.addOutput(tv6); + + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + + EXPECT_EQ(non_exclusive_resize_info.size(), 2); + + // tv2 is the output of the first slice, which is not exclusive as + // tv1 is also a producer of tv4. + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv1}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + + // Similary, tv5 is the output of the second slice, which is not exclusive + // as tv1 is also a producer of tv2. + EXPECT_EQ(non_exclusive_resize_info.count(tv5), 1); + scheduler_tools::ResizeExclusivityInfo tv5_info{ + {tv1}, exact_graph.toGroups(std::vector{tv4->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv5), tv5_info); + } + + // Test replication-based mitigation of conflicts + { + Fusion fusion_copy = fusion; + FusionGuard fg(&fusion_copy); + + auto tv0 = fusion_copy.inputs().at(0)->as(); + auto tv2 = + fusion_copy.outputs().at(0)->definition()->input(0)->as(); + auto slice = dynamic_cast(tv2->definition()); + ASSERT_NE(slice, nullptr); + auto tv1 = slice->input(0)->as(); + auto tv5 = + fusion_copy.outputs().at(1)->definition()->input(0)->as(); + auto tv4 = tv5->definition()->input(0)->as(); + + // Replicate tv1 for tv2 + auto private_copy = RecomputeTv::recompute(tv1); + ir_utils::replaceValInExprInputs(slice, tv1, private_copy); + + // The two slices should still be reported as non-exclusive but they + // both are shared at the fusion input. 
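+    // (After the recomputation above, tv2's slice consumes a private copy of
+    // tv1 while tv5 still goes through the original tv1, roughly:
+    //
+    //   tv0 -> sin -> tv1          -> sin -> tv4 -> slice -> tv5
+    //   tv0 -> sin -> private_copy -> slice -> tv2
+    //
+    // so the only tensor reachable from both slice paths is the fusion input
+    // tv0, which is why the expectations below report {tv0} rather than
+    // {tv1}.)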
+ IdModel id_model(&fusion_copy, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion_copy), exact_graph); + EXPECT_EQ(non_exclusive_resize_info.size(), 2); + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv0}, exact_graph.toGroups(std::vector{tv0->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + + EXPECT_EQ(non_exclusive_resize_info.count(tv5), 1); + scheduler_tools::ResizeExclusivityInfo tv5_info{ + {tv0}, exact_graph.toGroups(std::vector{tv4->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv5), tv5_info); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + + // While the slices can be transformed to be all exclusive, it is + // currently segmented as the output has differet shapes. Both + // segments should be scheduled as resize segments. + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + const auto& heuristic_list = runtime->schedulerHeuristics()->heuristicsList(); + EXPECT_EQ(heuristic_list.size(), 2); + EXPECT_EQ(heuristic_list[0]->scheduler_type, SchedulerType::Resize); + EXPECT_EQ(heuristic_list[1]->scheduler_type, SchedulerType::Resize); +} - auto tv3 = set(tv0); +// Non-exclusive slice due to a dependency to a fusion output +TEST_F(ResizeSchedulerTest, PropagateMultipleSlicesToInputs3) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor({-1}); + fusion.addInput(tv1); + + auto tv2 = sin(tv0); + + fusion.addOutput(tv2); + + auto tv3 = add(tv2, broadcast(tv1, {false, true})); auto tv4 = slice( tv3, {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, - {IrBuilder::create(shape[1] / 2), - IrBuilder::create(shape[1])}}); + {IrBuilder::create(1L), tv3->getLogicalDomain().at(1)->extent()}}); - auto tv5 = cat({tv4, tv2}, 1); + auto tv5 = sin(tv4); fusion.addOutput(tv5); - // Propagate the left half of slice and pad - scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto pad_left = - dynamic_cast(tv5->definition()->input(0)->definition()); - scheduler_tools::propagateResizeToInputs(pad_left); - - // Propagate the right half of slice and pad - scheduler_tools::propagateResizeToInputs(tv4->definition()); - auto pad_right = - dynamic_cast(tv5->definition()->input(1)->definition()); - scheduler_tools::propagateResizeToInputs(pad_right); + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); - auto ref_tv = tv5; + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); - - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + 
// tv4 is the input of the slice, which is not exclusive as + // tv3 depends on tv2, which is a fusion output + EXPECT_EQ(non_exclusive_resize_info.count(tv4), 1); + scheduler_tools::ResizeExclusivityInfo tv4_info{ + {tv2}, exact_graph.toGroups(std::vector{tv3->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv4), tv4_info); + // Test replication-based mitigation of conflicts { - IdModel id_model(&fusion, false); - id_model.buildExactGraph(); - std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); - auto dot_string = - id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); - ofs << dot_string; - ofs.close(); + Fusion fusion_copy = fusion; + FusionGuard fg(&fusion_copy); + + auto tv0 = fusion_copy.inputs().at(0)->as(); + auto tv5 = fusion_copy.outputs().at(1)->as(); + auto tv4 = tv5->definition()->input(0)->as(); + auto tv3 = tv4->definition()->input(0)->as(); + + auto private_copy = RecomputeTv::recompute(tv3); + ir_utils::replaceValInExprInputs(tv4->definition(), tv3, private_copy); + + IdModel id_model(&fusion_copy, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion_copy), exact_graph); + EXPECT_EQ(non_exclusive_resize_info.size(), 1); + EXPECT_EQ(non_exclusive_resize_info.count(tv4), 1); + scheduler_tools::ResizeExclusivityInfo tv4_info{ + {tv0}, exact_graph.toGroups(std::vector{tv0->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv4), tv4_info); } - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + GTEST_SKIP() << "Scheduling not yet supported due to broadcast"; - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + auto t1 = at::randn({16}, options); + std::vector inputs({t0, t1}); - inlineMost(); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get())->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); +} - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; +// Slice input tensor depends on a fusion output, but the slice is +// still considered exclusive as the fusion output has no +// corresponding ID for the sliced ID. More specifically, tv2 is a +// fusion output and has a dependency to the input of the +// slice. However, the resize is done for the second axis of tv3, +// for which tv2 has no corresponding ID. In this case, it should be +// safe to do the propagation of the resize. +// +// Note that scheduling is not yet supported due to the existence of +// the dependency from the slice input ID to the broadcast ID. 
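+//
+// The essential shape relationships in the test below (abbreviated, with
+// illustrative iter domain names):
+//
+//   tv2 = sin(tv1);                                // 1-D output, [i0]
+//   tv3 = add(tv0, broadcast(tv2, {false, true})); // [i0, i1]
+//   tv4 = slice(tv3, ...);                         // resize applies to i1
+//
+// Since tv2 has no iter domain exact-mapped to i1 (it only reaches tv3's
+// second axis through a broadcast), propagating the resize from tv4 toward
+// tv0 cannot alter anything observable through tv2, and the info map is
+// expected to be empty.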
+TEST_F(ResizeSchedulerTest, PropagateMultipleSlicesToInputs4) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({-1, 100}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor({shape[0]}); + fusion.addInput(tv1); + + auto tv2 = sin(tv1); + + fusion.addOutput(tv2); + + auto tv3 = add(tv0, broadcast(tv2, {false, true})); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(1L), tv3->getLogicalDomain().at(1)->extent()}}); + + auto tv5 = sin(tv4); + + fusion.addOutput(tv5); + + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + + EXPECT_TRUE(non_exclusive_resize_info.empty()); +} + +// Testing chained slices. Should be considered exclusive +TEST_P(ResizeSchedulerTest, PropagateMultipleSlicesToInputs5) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({-1, 100}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(1L), tv1->getLogicalDomain().at(1)->extent()}}); + + auto tv3 = slice( + tv2, + {{fusion.zeroVal(), tv2->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(3L), tv2->getLogicalDomain().at(1)->extent()}}); + + auto tv4 = sin(tv3); + + fusion.addOutput(tv4); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); + scheduler_tools::propagateResizeToInputs(tv3->definition()); + auto ref_tv = tv4; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); + + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + // The two slices do not conflict + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_TRUE(non_exclusive_resize_info.empty()); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + 
FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); } +} + +// Testing chained slices. The first slice is considered +// non-exclusive, but the following slice should not. +TEST_F(ResizeSchedulerTest, PropagateMultipleSlicesToInputs6) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({-1, 100}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(1L), tv1->getLogicalDomain().at(1)->extent()}}); + + auto tv3 = slice( + tv2, + {{fusion.zeroVal(), tv2->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(3L), tv2->getLogicalDomain().at(1)->extent()}}); + + auto tv4 = sin(tv3); + fusion.addOutput(tv4); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + auto tv5 = sin(tv1); + fusion.addOutput(tv5); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({16, 100}, options); std::vector inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + // The two slices do not conflict + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_EQ(non_exclusive_resize_info.size(), 1); + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv1}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + + // When scheduled, since the shape of the tv4 is different from the + // shape of tv5, this fusion is segmented. One segment is a resize + // segment consisting of tv2 and tv3 slices. Another is a pointwise + // segment for tv5. 
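+ // (Concretely, with the {16, 100} input above: the first slice drops one
+ // column (100 -> 99) and the second drops three more (99 -> 96), so tv4 is
+ // {16, 96} while tv5 keeps the full {16, 100} extent. The two outputs
+ // therefore cannot be scheduled with a single reference loop domain, which
+ // is what forces the segmentation checked below.)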
+ FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + const auto& heuristic_list = runtime->schedulerHeuristics()->heuristicsList(); + EXPECT_EQ(heuristic_list.size(), 2); + // They should be a combination of a resize scheduler and a pointwise + // scheduler + EXPECT_TRUE( + (heuristic_list[0]->scheduler_type == SchedulerType::PointWise && + heuristic_list[1]->scheduler_type == SchedulerType::Resize) || + (heuristic_list[0]->scheduler_type == SchedulerType::Resize && + heuristic_list[1]->scheduler_type == SchedulerType::PointWise)); +} + +// RoPE-like rotation patten +TEST_P(ResizeSchedulerTest, SliceRotateCat) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = sin(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv5 = cat({tv4, tv2}, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + fusion.addOutput(tv5); + + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv5; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + // tv1 is not considered exclusive as tv0 is also a consumer of + // tv3. Same for tv3. While the common input, tv0, is a fusion + // input, so it isn't actually scheduled, since a cache is + // inserted, which is indeed scheduled, the two slices do + // conflict. 
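+ // (Restated: tv0 produces both tv1 and tv3, so the two slice paths share
+ // it. The resize scheduler caches fusion inputs, roughly along the lines
+ // of
+ //
+ //   auto tv0_cache = tv0->cacheAfter();  // scheduled stand-in for tv0
+ //
+ // and that scheduled copy is what would receive the resizes propagated
+ // from both slices, which is why both tv2 and tv4 are reported below.)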
+ IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv0}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + EXPECT_EQ(non_exclusive_resize_info.count(tv4), 1); + scheduler_tools::ResizeExclusivityInfo tv4_info{ + {tv0}, exact_graph.toGroups(std::vector{tv3->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv4), tv4_info); + // These two entries should be all the info map has. + EXPECT_EQ(non_exclusive_resize_info.size(), 2); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // RoPE-like rotation and residual patten -TEST_F(ResizeTest, SliceRotateCatResidual) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, SliceRotateCatResidual) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4383,14 +4999,14 @@ TEST_F(ResizeTest, SliceRotateCatResidual) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); - auto tv3 = set(tv0); + auto tv3 = sin(tv0); auto tv4 = slice( tv3, @@ -4404,74 +5020,212 @@ TEST_F(ResizeTest, SliceRotateCatResidual) { fusion.addOutput(tv6); - // Propagate the left half of slice and pad - scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto pad_left = - dynamic_cast(tv5->definition()->input(1)->definition()); - scheduler_tools::propagateResizeToInputs(pad_left); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - // Propagate the right half of slice and pad - scheduler_tools::propagateResizeToInputs(tv4->definition()); - auto pad_right = - dynamic_cast(tv5->definition()->input(0)->definition()); - scheduler_tools::propagateResizeToInputs(pad_right); + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv6; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // 
Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } - auto ref_tv = tv6; + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + inlineMost(); - { - IdModel id_model(&fusion, false); - id_model.buildExactGraph(); - std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); - auto dot_string = - id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); - ofs << dot_string; - ofs.close(); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()) + << "Invalid computeAt position of " << tv->toString(); + } + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + // tv1 is not considered exclusive as tv0 is also a consumer of + // tv3. Same for tv3. While the common input, tv0, is a fusion + // input, so it isn't actually scheduled, since a cache is + // inserted, which is indeed scheduled, the two slices do + // conflict. + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv0}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + EXPECT_EQ(non_exclusive_resize_info.count(tv4), 1); + scheduler_tools::ResizeExclusivityInfo tv4_info{ + {tv0}, exact_graph.toGroups(std::vector{tv3->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv4), tv4_info); + // These two entries should be all the info map has. + EXPECT_EQ(non_exclusive_resize_info.size(), 2); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); } +} + +// Rotate twice. Resolving the non-exclusivity must be done in a +// topological order. 
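+//
+// The conflicts come in two layers: tv2 and tv3 both slice tv1, and tv5 and
+// tv6 both slice tv4, the output of the first cat (abbreviated):
+//
+//   tv1 -> {slice -> tv2, slice -> tv3} -> cat -> tv4
+//   tv4 -> {slice -> tv5, slice -> tv6} -> cat -> tv7
+//
+// Presumably a private copy of tv4 made for tv5 or tv6 would still contain
+// the unresolved sharing of tv1, so the resolution has to proceed from
+// producers to consumers.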
+TEST_F(ResizeSchedulerTest, SliceRotateCatTwice) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + std::vector shape({-1, 100}); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); - inlineMost(); + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()) - << "Invalid computeAt position of " << tv->toString(); - } + auto tv1 = sin(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + auto tv4 = cat({tv3, tv2}, -1); + + auto tv5 = slice( + tv4, + {{fusion.zeroVal(), tv4->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv6 = slice( + tv4, + {{fusion.zeroVal(), tv4->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv7 = cat({tv6, tv5}, -1); + + fusion.addOutput(tv7); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({16, 100}, options); std::vector inputs({t0}); - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + // tv1 is not considered exclusive as tv0 is also a consumer of + // tv3. Same for tv3. While the common input, tv0, is a fusion + // input, so it isn't actually scheduled, since a cache is + // inserted, which is indeed scheduled, the two slices do + // conflict. + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + + // tv2 + EXPECT_EQ(non_exclusive_resize_info.count(tv2), 1); + scheduler_tools::ResizeExclusivityInfo tv2_info{ + {tv1}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv2), tv2_info); + + // tv3 + EXPECT_EQ(non_exclusive_resize_info.count(tv3), 1); + scheduler_tools::ResizeExclusivityInfo tv3_info{ + {tv1}, exact_graph.toGroups(std::vector{tv1->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv3), tv3_info); + + // tv5 + EXPECT_EQ(non_exclusive_resize_info.count(tv5), 1); + scheduler_tools::ResizeExclusivityInfo tv5_info{ + {tv4}, exact_graph.toGroups(std::vector{tv4->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv5), tv5_info); + + // tv6 + EXPECT_EQ(non_exclusive_resize_info.count(tv6), 1); + scheduler_tools::ResizeExclusivityInfo tv6_info{ + {tv4}, exact_graph.toGroups(std::vector{tv4->axis(1)})}; + EXPECT_EQ(non_exclusive_resize_info.at(tv6), tv6_info); + + // These should be all the info the map has. 
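+ // (Four entries in total: tv2 and tv3 keyed to the shared producer tv1,
+ // plus tv5 and tv6 keyed to tv4.)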
+ EXPECT_EQ(non_exclusive_resize_info.size(), 4); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); } // Consumer-based scheduling of pad -TEST_F(ResizeTest, PropagatePadToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagatePadToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4481,61 +5235,87 @@ TEST_F(ResizeTest, PropagatePadToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = pad(tv1, {fusion.oneVal(), IrBuilder::create(2L)}); - auto tv3 = set(tv2); + auto tv3 = cos(tv2); fusion.addOutput(tv3); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - auto ref_tv = tv3; + const bool use_scheduler = GetParam(); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv3; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + 
EXPECT_TRUE(non_exclusive_resize_info.empty()); - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Consumer-based scheduling of cat -TEST_F(ResizeTest, PropagateCatToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateCatToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4547,65 +5327,90 @@ TEST_F(ResizeTest, PropagateCatToInputs) { auto tv1 = makeConcreteTensor(shape); fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = set(tv1); + auto tv2 = sin(tv0); + auto tv3 = sin(tv1); auto tv4 = cat({tv2, tv3}, -1); - auto tv5 = set(tv4); + auto tv5 = cos(tv4); fusion.addOutput(tv5); - // Propagate the pad op of each cat input - for (auto cat_inp : - ir_utils::filterByType(tv4->definition()->inputs())) { - auto pad_op = dynamic_cast(cat_inp->definition()); - ASSERT_NE(pad_op, nullptr); - scheduler_tools::propagateResizeToInputs(pad_op); - auto pad_inp = pad_op->input(0)->as(); - checkLoopDomainEquivalence(cat_inp, {pad_inp}); - } + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + auto t1 = at::randn({16, 100}, options); + std::vector inputs({t0, t1}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the pad op of each cat input + for (auto cat_inp : + ir_utils::filterByType(tv4->definition()->inputs())) { + auto pad_op = dynamic_cast(cat_inp->definition()); + ASSERT_NE(pad_op, nullptr); + scheduler_tools::propagateResizeToInputs(pad_op); + auto pad_inp = pad_op->input(0)->as(); + checkLoopDomainEquivalence(cat_inp, {pad_inp}); + } - // At this point, all tensors should have the same loop domain - checkLoopDomainEquivalence(ref_tv); + auto ref_tv = tv4; - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + // At this point, all tensors should have the same loop domain + checkLoopDomainEquivalence(ref_tv); - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - inlineMost(); + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + inlineMost(); + + // All tensors, except for fusion inputs, should be fully 
inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - auto t1 = at::randn({16, 100}, options); - std::vector inputs({t0, t1}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + auto non_exclusive_resize_info = scheduler_tools::getNonExclusiveResizeInfo( + ir_utils::getOpsOfType(&fusion), exact_graph); + EXPECT_TRUE(non_exclusive_resize_info.empty()); - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // manual scheduling that should have vectorized load on padded inputs. @@ -4837,53 +5642,6 @@ TEST_F(ResizeTest, VectorizePadNonInnermost) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -// padding with negative extent should prevent us considering the resize id for -// vectorization. 
So the example below should only have a vectorization factor -// of 2 -TEST_F(ResizeTest, VectorizePadNonInnermostNegativeExtent) { - Fusion fusion; - FusionGuard fg(&fusion); - - const std::vector shape({1024L, 1024L, 2L}); - - // Using a concrete tensor to avoid dynamic resize - auto tv0 = makeContigConcreteTensor(shape); - fusion.addInput(tv0); - - auto tv1 = - pad(tv0, - {IrBuilder::create(0L), - IrBuilder::create(0L), - IrBuilder::create(-4L), - IrBuilder::create(4L), - IrBuilder::create(0L), - IrBuilder::create(0L)}); - fusion.addOutput(tv1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector aten_inputs({t0}); - auto cg_outputs = - scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; - - // check that we vectorize 4 - bool found_vectorize = false; - auto exprs = fusion.exprs(); - auto pad_ops = ir_utils::filterByType(exprs).vector(); - EXPECT_EQ(pad_ops.size(), 1); - EXPECT_TRUE(pad_ops.at(0)->out()->isA()); - for (auto id : pad_ops.at(0)->out()->as()->getLoopDomain()) { - if (id->getParallelType() == ParallelType::Vectorize) { - EXPECT_EQ(id->extent()->evaluate(), 2); - found_vectorize = true; - break; - } - } - EXPECT_TRUE(found_vectorize); - - testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); -} - TEST_F(ResizeTest, PadAndCacheUses) { Fusion fusion; FusionGuard fg(&fusion); diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp new file mode 100644 index 00000000000..ddec92d58e9 --- /dev/null +++ b/tests/cpp/test_rope.cpp @@ -0,0 +1,897 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +struct RopeConfig { + int64_t n_head = -1; + int64_t head_size = -1; + int64_t n_query_groups = -1; + int64_t rope_n_elem = -1; + int64_t batches = -1; + int64_t seq_length = -1; + + void verify() const { + ASSERT_EQ(n_head % n_query_groups, 0); + } + + std::string toString() const { + std::stringstream ss; + ss << "{n_head: " << n_head << ", head_size: " << head_size + << ", n_query_groups: " << n_query_groups + << ", rope_n_elem: " << rope_n_elem << ", batches: " << batches + << ", seq_length: " << seq_length << "}"; + return ss.str(); + } + + std::string toCompactString() const { + std::stringstream ss; + ss << n_head << "_" << head_size << "_" << n_query_groups << "_" + << rope_n_elem << "_" << batches << "_" << seq_length; + return ss.str(); + } +}; + +class RopeTest : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + NVFuserTest::SetUp(); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; + +using MistralRopeTest = RopeTest; + +INSTANTIATE_TEST_SUITE_P( + , + MistralRopeTest, + testing::Values(RopeConfig{ + /*n_head=*/32, + /*head_size=*/128, + /*n_query_groups=*/8, + /*rope_n_elem=*/128, + /*n_batches=*/1, + /*seq_length=*/4096}), + [](const testing::TestParamInfo& info) { + return info.param.toCompactString(); + }); + +// Mistral forward before matmul +// clang-format off +/* +def nvfuser_fusion_id0(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 4096, 1024], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0]) + T1 = 
fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0]) + T2 = fd.define_tensor(shape=[1, 4096], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0]) + T8 = fd.ops.reshape(T0, new_shape=[1, 4096, 8, 128]) + T9 = fd.ops.permute(T8, dims=[0, 2, 1, 3]) + T14 = fd.ops.broadcast_in_dim(T1, shape=[1, 64, 1], broadcast_dims=[1]) + T15 = fd.ops.cast(T14, dtype=DataType.Float) + T20 = fd.ops.broadcast_in_dim(T15, shape=[1, 64, 1], broadcast_dims=[0, 1, 2]) + T25 = fd.ops.broadcast_in_dim(T2, shape=[1, 1, 4096], broadcast_dims=[0, 2]) + T26 = fd.ops.cast(T25, dtype=DataType.Float) + T33 = fd.ops.broadcast_in_dim(T9, shape=[1, 8, 1, 4096, 128], broadcast_dims=[0, 1, 3, 4]) + T40 = fd.ops.broadcast_in_dim(T33, shape=[1, 8, 4, 4096, 128], broadcast_dims=[0, 1, 2, 3, 4]) + T46 = fd.ops.reshape(T40, new_shape=[1, 32, 4096, 128]) + fd.add_output(T20) + fd.add_output(T26) + fd.add_output(T46) +*/ +// clang-format on +TEST_P(MistralRopeTest, Fwd1) { + const RopeConfig config = GetParam(); + config.verify(); + + const int64_t batch_size = config.batches; + const int64_t seq_len = config.seq_length; + const int64_t head_dim = config.head_size; + const int64_t num_attention_heads = config.n_head; + const int64_t num_key_value_heads = config.n_query_groups; + + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + std::vector shape1{ + batch_size, seq_len, head_dim * num_key_value_heads}; + std::vector shape2{head_dim / 2}; + std::vector shape3{batch_size, seq_len}; + + auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor(shape3, DataType::Int); + fusion.addInput(tv2); + + // T3 + auto tv8 = reshape( + tv0, shape1, {batch_size, seq_len, num_key_value_heads, head_dim}); + // T4 + auto tv9 = permute(tv8, {0, 2, 1, 3}); + // T5 + auto tv14 = broadcast(tv1, {true, false, true}); + // T6 + auto tv15 = castOp(DataType::Float, tv14); + // T7. 
This is actually converted to just a set op + auto tv20 = expand( + tv15, + std::vector{ + IrBuilder::create(1L), + IrBuilder::create(head_dim / 2), + IrBuilder::create(1L)}); + // T8 + auto tv25 = broadcast(tv2, {false, true, false}); + // T9 + auto tv26 = castOp(DataType::Float, tv25); + // T10 + auto tv33 = broadcast(tv9, {false, false, true, false, false}); + // T11 + auto tv40 = expand( + tv33, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(num_attention_heads / num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T12 + auto tv46 = reshape( + tv40, + {batch_size, + num_key_value_heads, + num_attention_heads / num_key_value_heads, + seq_len, + head_dim}, + {batch_size, num_attention_heads, seq_len, head_dim}); + fusion.addOutput(tv20); + fusion.addOutput(tv26); + fusion.addOutput(tv46); + + auto options_float = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_bf16 = + at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options_bf16); + auto t1 = at::randn(shape2, options_bf16); + auto t2 = at::randn(shape3, options_float).to(at::kLong); + std::vector inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); +} + +// Mistral forward after matmul +// clang-format off +/* +def nvfuser_fusion_id1(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 4096, 4096], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0]) + T1 = fd.define_tensor(shape=[1, 4096, 1024], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0]) + T2 = fd.define_tensor(shape=[1, 64, 4096], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0]) + T8 = fd.ops.reshape(T0, new_shape=[1, 4096, 32, 128]) + T9 = fd.ops.permute(T8, dims=[0, 2, 1, 3]) + T15 = fd.ops.reshape(T1, new_shape=[1, 4096, 8, 128]) + T16 = fd.ops.permute(T15, dims=[0, 2, 1, 3]) + T17 = fd.ops.permute(T2, dims=[0, 2, 1]) + T18 = fd.ops.cat([T17, T17], dim=-1, manual_padding=0) + T19 = fd.ops.cos(T18) + T20 = fd.ops.sin(T18) + T21 = fd.ops.cast(T19, dtype=DataType.BFloat16) + T22 = fd.ops.cast(T20, dtype=DataType.BFloat16) + T28 = fd.ops.broadcast_in_dim(T21, shape=[1, 1, 4096, 128], broadcast_dims=[0, 2, 3]) + T34 = fd.ops.broadcast_in_dim(T22, shape=[1, 1, 4096, 128], broadcast_dims=[0, 2, 3]) + T40 = fd.ops.broadcast_in_dim(T28, shape=[1, 32, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T41 = fd.ops.cast(T9, dtype=DataType.Float) + T42 = fd.ops.cast(T40, dtype=DataType.Float) + T43 = fd.ops.mul(T41, T42) + T59 = fd.ops.slice(T9, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 4096, 64], strides=[1, 1, 1, 1], manual_normalization=0) + T75 = fd.ops.slice(T9, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 4096, 128], strides=[1, 1, 1, 1], manual_normalization=0) + T76 = fd.ops.cast(T75, dtype=DataType.Float) + T77 = fd.ops.neg(T76) + T78 = fd.ops.cast(T77, dtype=DataType.BFloat16) + T79 = fd.ops.cat([T78, T59], dim=-1, manual_padding=0) + T85 = fd.ops.broadcast_in_dim(T34, shape=[1, 32, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T86 = fd.ops.cast(T79, dtype=DataType.Float) + T87 = fd.ops.cast(T85, dtype=DataType.Float) + T88 = fd.ops.mul(T86, T87) + T89 = fd.ops.add(T43, T88) + T90 = 
fd.ops.cast(T89, dtype=DataType.BFloat16) + T96 = fd.ops.broadcast_in_dim(T28, shape=[1, 8, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T97 = fd.ops.cast(T16, dtype=DataType.Float) + T98 = fd.ops.cast(T96, dtype=DataType.Float) + T99 = fd.ops.mul(T97, T98) + T115 = fd.ops.slice(T16, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 4096, 64], strides=[1, 1, 1, 1], manual_normalization=0) + T131 = fd.ops.slice(T16, start_indices=[0, 0, 0, 64], end_indices=[1, 8, 4096, 128], strides=[1, 1, 1, 1], manual_normalization=0) + T132 = fd.ops.cast(T131, dtype=DataType.Float) + T133 = fd.ops.neg(T132) + T134 = fd.ops.cast(T133, dtype=DataType.BFloat16) + T135 = fd.ops.cat([T134, T115], dim=-1, manual_padding=0) + T141 = fd.ops.broadcast_in_dim(T34, shape=[1, 8, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T142 = fd.ops.cast(T135, dtype=DataType.Float) + T143 = fd.ops.cast(T141, dtype=DataType.Float) + T144 = fd.ops.mul(T142, T143) + T145 = fd.ops.add(T99, T144) + T146 = fd.ops.cast(T145, dtype=DataType.BFloat16) + T153 = fd.ops.broadcast_in_dim(T146, shape=[1, 8, 1, 4096, 128], broadcast_dims=[0, 1, 3, 4]) + T160 = fd.ops.broadcast_in_dim(T153, shape=[1, 8, 4, 4096, 128], broadcast_dims=[0, 1, 2, 3, 4]) + T166 = fd.ops.reshape(T160, new_shape=[1, 32, 4096, 128]) + fd.add_output(T90) + fd.add_output(T166) +*/ +// clang-format on +TEST_P(MistralRopeTest, Fwd2) { + const RopeConfig config = GetParam(); + config.verify(); + + const int64_t batch_size = config.batches; + const int64_t seq_len = config.seq_length; + const int64_t head_dim = config.head_size; + const int64_t num_attention_heads = config.n_head; + const int64_t num_key_value_heads = config.n_query_groups; + + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + std::vector shape1{ + batch_size, seq_len, head_dim * num_attention_heads}; + std::vector shape2{ + batch_size, seq_len, head_dim * num_key_value_heads}; + std::vector shape3{batch_size, head_dim / 2, seq_len}; + + auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor(shape3, DataType::Float); + fusion.addInput(tv2); + + // T3 + auto tv8 = reshape( + tv0, shape1, {batch_size, seq_len, num_attention_heads, head_dim}); + // T4 + auto tv9 = permute(tv8, {0, 2, 1, 3}); + // T5 + auto tv15 = reshape( + tv1, shape2, {batch_size, seq_len, num_key_value_heads, head_dim}); + // T6 + auto tv16 = permute(tv15, {0, 2, 1, 3}); + // T7 + auto tv17 = permute(tv2, {0, 2, 1}); + // T8 = pad(T7) + // T9 = pad(T7) + // T10 + auto tv18 = cat({tv17, tv17}, -1); + // T11 + auto tv19 = cos(tv18); + // T12 + auto tv20 = sin(tv18); + // T13 + auto tv21 = castOp(DataType::BFloat16, tv19); + // T14 + auto tv22 = castOp(DataType::BFloat16, tv20); + // T15 + auto tv28 = broadcast(tv21, {false, true, false, false}); + // T16 + auto tv34 = broadcast(tv22, {false, true, false, false}); + // T17 + auto tv40 = expand( + tv28, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_attention_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T18 + auto tv41 = castOp(DataType::Float, tv9); + // T19 + auto tv42 = castOp(DataType::Float, tv40); + // T20 + auto tv43 = mul(tv41, tv42); + // T21 + auto tv59 = slice( + tv9, + {{fusion.zeroVal(), tv9->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), tv9->getLogicalDomain().at(1)->extent()}, + {fusion.zeroVal(), 
tv9->getLogicalDomain().at(2)->extent()}, + {fusion.zeroVal(), IrBuilder::create(head_dim / 2)}}); + // T22 + auto tv75 = slice( + tv9, + {{fusion.zeroVal(), tv9->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), tv9->getLogicalDomain().at(1)->extent()}, + {fusion.zeroVal(), tv9->getLogicalDomain().at(2)->extent()}, + {IrBuilder::create(head_dim / 2), + tv9->getLogicalDomain().at(3)->extent()}}); + // T23 + auto tv76 = castOp(DataType::Float, tv75); + // T24 + auto tv77 = neg(tv76); + // T25 + auto tv78 = castOp(DataType::BFloat16, tv77); + // T26 = pad(T25) + // T27 = pad(T21) + // T28 + auto tv79 = cat({tv78, tv59}, -1); + // T29 + auto tv85 = expand( + tv34, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_attention_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T30 + auto tv86 = castOp(DataType::Float, tv79); + // T31 + auto tv87 = castOp(DataType::Float, tv85); + // T32 + auto tv88 = mul(tv86, tv87); + // T33 + auto tv89 = add(tv43, tv88); + // T34 + auto tv90 = castOp(DataType::BFloat16, tv89); + + // T35 + auto tv96 = expand( + tv28, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T36 + auto tv97 = castOp(DataType::Float, tv16); + // T37 + auto tv98 = castOp(DataType::Float, tv96); + // T38 + auto tv99 = mul(tv97, tv98); + // T39 + auto tv115 = slice( + tv16, + {{fusion.zeroVal(), tv16->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), tv16->getLogicalDomain().at(1)->extent()}, + {fusion.zeroVal(), tv16->getLogicalDomain().at(2)->extent()}, + {fusion.zeroVal(), IrBuilder::create(head_dim / 2)}}); + // T40 + auto tv131 = slice( + tv16, + {{fusion.zeroVal(), tv16->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), tv16->getLogicalDomain().at(1)->extent()}, + {fusion.zeroVal(), tv16->getLogicalDomain().at(2)->extent()}, + {IrBuilder::create(head_dim / 2), + tv16->getLogicalDomain().at(3)->extent()}}); + // T41 + auto tv132 = castOp(DataType::Float, tv131); + // T42 + auto tv133 = neg(tv132); + // T43 + auto tv134 = castOp(DataType::BFloat16, tv133); + // T44 = pad(T43) + // T45 = pad(T39) + // T46 + auto tv135 = cat({tv134, tv115}, -1); + // T47 + auto tv141 = expand( + tv34, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T48 + auto tv142 = castOp(DataType::Float, tv135); + // T49 + auto tv143 = castOp(DataType::Float, tv141); + // T50 + auto tv144 = mul(tv142, tv143); + // T51 + auto tv145 = add(tv99, tv144); + // T52 + auto tv146 = castOp(DataType::BFloat16, tv145); + // T53 + auto tv153 = broadcast(tv146, {false, false, true, false, false}); + // T54 + auto tv160 = expand( + tv153, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(num_attention_heads / num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + // T55 + auto tv166 = reshape( + tv160, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_attention_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + + fusion.addOutput(tv90); + fusion.addOutput(tv166); + + auto options_fp32 = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_bf16 = + at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options_bf16); + auto t1 = at::randn(shape2, 
options_bf16); + auto t2 = at::randn(shape3, options_fp32); + std::vector inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); +} + +// clang-format off +/* +def nvfuser_fusion_id2(fd : FusionDefinition) -> None : + T0 = fd.define_tensor(shape=[1, 32, 4096, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0]) + T1 = fd.define_tensor(shape=[1, 64, 4096], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0]) + T2 = fd.define_tensor(shape=[1, 32, 4096, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0]) + T3 = fd.define_tensor(shape=[1, 32, 4096, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0]) + T10 = fd.ops.reshape(T0, new_shape=[1, 8, 4, 4096, 128]) + T11 = fd.ops.cast(T10, dtype=DataType.Float) + T12 = fd.ops.sum(T11, dims=[0, 2], keepdim=False, dtype=DataType.Null) + T13 = fd.ops.permute(T1, dims=[0, 2, 1]) + T14 = fd.ops.cast(T12, dtype=DataType.BFloat16) + T15 = fd.ops.cat([T13, T13], dim=-1, manual_padding=0) + T22 = fd.ops.broadcast_in_dim(T14, shape=[1, 8, 1, 4096, 128], broadcast_dims=[1, 3, 4]) + T23 = fd.ops.sin(T15) + T24 = fd.ops.cast(T22, dtype=DataType.Float) + T25 = fd.ops.cast(T23, dtype=DataType.BFloat16) + T26 = fd.ops.sum(T24, dims=[0, 2], keepdim=False, dtype=DataType.Null) + T32 = fd.ops.broadcast_in_dim(T25, shape=[1, 1, 4096, 128], broadcast_dims=[0, 2, 3]) + T33 = fd.ops.cast(T26, dtype=DataType.BFloat16) + T39 = fd.ops.broadcast_in_dim(T32, shape=[1, 32, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T45 = fd.ops.broadcast_in_dim(T33, shape=[1, 8, 4096, 128], broadcast_dims=[1, 2, 3]) + T51 = fd.ops.broadcast_in_dim(T32, shape=[1, 8, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + T52 = fd.ops.cast(T2, dtype=DataType.Float) + T53 = fd.ops.cast(T39, dtype=DataType.Float) + T54 = fd.ops.cast(T45, dtype=DataType.Float) + T55 = fd.ops.cast(T51, dtype=DataType.Float) + T56 = fd.ops.mul(T53, T52) + T57 = fd.ops.mul(T55, T54) + T58 = fd.ops.cast(T56, dtype=DataType.BFloat16) + T59 = fd.ops.cast(T57, dtype=DataType.BFloat16) + T75 = fd.ops.slice(T58, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 4096, 64], strides=[1, 1, 1, 1], manual_normalization=0) + T91 = fd.ops.slice(T59, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 4096, 64], strides=[1, 1, 1, 1], manual_normalization=0) + T98 = fd.ops.reshape(T3, new_shape=[1, 8, 4, 4096, 128]) + T99 = fd.ops.cos(T15) + T100 = fd.ops.cast(T75, dtype=DataType.Float) + T101 = fd.ops.cast(T91, dtype=DataType.Float) + T102 = fd.ops.cast(T98, dtype=DataType.Float) + T103 = fd.ops.cast(T99, dtype=DataType.BFloat16) + T104 = fd.ops.neg(T100) + T105 = fd.ops.neg(T101) + T106 = fd.ops.sum(T102, dims=[0, 2], keepdim=False, dtype=DataType.Null) + T112 = fd.ops.broadcast_in_dim(T103, shape=[1, 1, 4096, 128], broadcast_dims=[0, 2, 3]) + T128 = fd.ops.slice(T58, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 4096, 128], strides=[1, 1, 1, 1], manual_normalization=0) + T129 = fd.ops.cast(T104, dtype=DataType.BFloat16) + T145 = fd.ops.slice(T59, start_indices=[0, 0, 0, 64], end_indices=[1, 8, 4096, 128], strides=[1, 1, 1, 1], manual_normalization=0) + T146 = fd.ops.cast(T105, dtype=DataType.BFloat16) + T147 = fd.ops.cast(T106, dtype=DataType.BFloat16) + T153 = 
fd.ops.broadcast_in_dim(T112, shape=[1, 32, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + S154 = fd.define_scalar(0.00000, dtype=DataType.Double) + T164 = fd.ops.pad(T128, [0, 64, 0, 0, 0, 0, 0, 0], S154) + S165 = fd.define_scalar(0.00000, dtype=DataType.Double) + T175 = fd.ops.pad(T129, [64, 0, 0, 0, 0, 0, 0, 0], S165) + T181 = fd.ops.broadcast_in_dim(T112, shape=[1, 8, 4096, 128], broadcast_dims=[0, 1, 2, 3]) + S182 = fd.define_scalar(0.00000, dtype=DataType.Double) + T192 = fd.ops.pad(T145, [0, 64, 0, 0, 0, 0, 0, 0], S182) + S193 = fd.define_scalar(0.00000, dtype=DataType.Double) + T203 = fd.ops.pad(T146, [64, 0, 0, 0, 0, 0, 0, 0], S193) + T210 = fd.ops.broadcast_in_dim(T147, shape=[1, 8, 1, 4096, 128], broadcast_dims=[1, 3, 4]) + T211 = fd.ops.cast(T153, dtype=DataType.Float) + T212 = fd.ops.cast(T164, dtype=DataType.Float) + T213 = fd.ops.cast(T175, dtype=DataType.Float) + T214 = fd.ops.cast(T181, dtype=DataType.Float) + T215 = fd.ops.cast(T192, dtype=DataType.Float) + T216 = fd.ops.cast(T203, dtype=DataType.Float) + T217 = fd.ops.cast(T210, dtype=DataType.Float) + T218 = fd.ops.mul(T211, T52) + T219 = fd.ops.add(T213, T212) + T220 = fd.ops.mul(T214, T54) + T221 = fd.ops.add(T216, T215) + T222 = fd.ops.sum(T217, dims=[0, 2], keepdim=False, dtype=DataType.Null) + T223 = fd.ops.add(T219, T218) + T224 = fd.ops.add(T221, T220) + T225 = fd.ops.cast(T222, dtype=DataType.BFloat16) + T226 = fd.ops.cast(T223, dtype=DataType.BFloat16) + T227 = fd.ops.cast(T224, dtype=DataType.BFloat16) + T233 = fd.ops.broadcast_in_dim(T225, shape=[1, 8, 4096, 128], broadcast_dims=[1, 2, 3]) + T234 = fd.ops.permute(T226, dims=[0, 2, 1, 3]) + T235 = fd.ops.permute(T227, dims=[0, 2, 1, 3]) + T236 = fd.ops.permute(T233, dims=[0, 2, 1, 3]) + T241 = fd.ops.reshape(T234, new_shape=[1, 4096, 4096]) + T246 = fd.ops.reshape(T235, new_shape=[1, 4096, 1024]) + T251 = fd.ops.reshape(T236, new_shape=[1, 4096, 1024]) + fd.add_output(T251) + fd.add_output(T246) + fd.add_output(T241) +*/ +// clang-format on +TEST_P(MistralRopeTest, Bwd) { + const RopeConfig config = GetParam(); + config.verify(); + + const int64_t batch_size = config.batches; + const int64_t seq_len = config.seq_length; + const int64_t head_dim = config.head_size; + const int64_t num_attention_heads = config.n_head; + const int64_t num_key_value_heads = config.n_query_groups; + + std::vector shape0{ + batch_size, num_attention_heads, seq_len, head_dim}; + std::vector shape1{batch_size, head_dim / 2, seq_len}; + std::vector shape2{ + batch_size, num_attention_heads, seq_len, head_dim}; + std::vector shape3{ + batch_size, num_attention_heads, seq_len, head_dim}; + + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto T0 = makeContigConcreteTensor(shape0, DataType::BFloat16); + fusion.addInput(T0); + auto T1 = makeContigConcreteTensor(shape1, DataType::Float); + fusion.addInput(T1); + auto T2 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(T2); + auto T3 = makeContigConcreteTensor(shape3, DataType::BFloat16); + fusion.addInput(T3); + + auto T10 = reshape( + T0, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(num_attention_heads / num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T11 = castOp(DataType::Float, T10); + auto T12 = sum(T11, {0, 2}); + auto T13 = permute(T1, {0, 2, 1}); + auto T14 = castOp(DataType::BFloat16, T12); + auto T15 = cat({T13, T13}, -1); + auto T22 = 
broadcast(T14, {true, false, true, false, false}); + auto T23 = sin(T15); + auto T24 = castOp(DataType::Float, T22); + auto T25 = castOp(DataType::BFloat16, T23); + auto T26 = sum(T24, {0, 2}); + auto T32 = broadcast(T25, {false, true, false, false}); + auto T33 = castOp(DataType::BFloat16, T26); + auto T39 = expand( + T32, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_attention_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T45 = broadcast(T33, {true, false, false, false}); + auto T51 = expand( + T32, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T52 = castOp(DataType::Float, T2); + auto T53 = castOp(DataType::Float, T39); + auto T54 = castOp(DataType::Float, T45); + auto T55 = castOp(DataType::Float, T51); + auto T56 = mul(T53, T52); + auto T57 = mul(T55, T54); + auto T58 = castOp(DataType::BFloat16, T56); + auto T59 = castOp(DataType::BFloat16, T57); + auto T75 = slice( + T58, + {{IrBuilder::create(0L), IrBuilder::create(batch_size)}, + {IrBuilder::create(0L), + IrBuilder::create(num_attention_heads)}, + {IrBuilder::create(0L), IrBuilder::create(seq_len)}, + {IrBuilder::create(0L), IrBuilder::create(head_dim / 2)}}); + auto T91 = slice( + T59, + {{IrBuilder::create(0L), IrBuilder::create(batch_size)}, + {IrBuilder::create(0L), + IrBuilder::create(num_key_value_heads)}, + {IrBuilder::create(0L), IrBuilder::create(seq_len)}, + {IrBuilder::create(0L), IrBuilder::create(head_dim / 2)}}); + auto T98 = reshape( + T3, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(num_attention_heads / num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T99 = cos(T15); + auto T100 = castOp(DataType::Float, T75); + auto T101 = castOp(DataType::Float, T91); + auto T102 = castOp(DataType::Float, T98); + auto T103 = castOp(DataType::BFloat16, T99); + auto T104 = neg(T100); + auto T105 = neg(T101); + auto T106 = sum(T102, {0, 2}); + auto T112 = broadcast(T103, {false, true, false, false}); + auto T128 = slice( + T58, + {{IrBuilder::create(0L), IrBuilder::create(batch_size)}, + {IrBuilder::create(0L), + IrBuilder::create(num_attention_heads)}, + {IrBuilder::create(0L), IrBuilder::create(seq_len)}, + {IrBuilder::create(head_dim / 2), + IrBuilder::create(head_dim)}}); + auto T129 = castOp(DataType::BFloat16, T104); + auto T145 = slice( + T59, + {{IrBuilder::create(0L), IrBuilder::create(batch_size)}, + {IrBuilder::create(0L), + IrBuilder::create(num_key_value_heads)}, + {IrBuilder::create(0L), IrBuilder::create(seq_len)}, + {IrBuilder::create(head_dim / 2), + IrBuilder::create(head_dim)}}); + auto T146 = castOp(DataType::BFloat16, T105); + auto T147 = castOp(DataType::BFloat16, T106); + auto T153 = expand( + T112, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_attention_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T164 = pad( + T128, {IrBuilder::create(0L), IrBuilder::create(head_dim / 2)}); + auto T175 = pad( + T129, {IrBuilder::create(head_dim / 2), IrBuilder::create(0L)}); + auto T181 = expand( + T112, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(num_key_value_heads), + IrBuilder::create(seq_len), + IrBuilder::create(head_dim)}); + auto T192 = pad( + T145, {IrBuilder::create(0L), IrBuilder::create(head_dim / 2)}); + auto T203 = pad( + T146, 
{IrBuilder::create(head_dim / 2), IrBuilder::create(0L)}); + auto T210 = broadcast(T147, {true, false, true, false, false}); + auto T211 = castOp(DataType::Float, T153); + auto T212 = castOp(DataType::Float, T164); + auto T213 = castOp(DataType::Float, T175); + auto T214 = castOp(DataType::Float, T181); + auto T215 = castOp(DataType::Float, T192); + auto T216 = castOp(DataType::Float, T203); + auto T217 = castOp(DataType::Float, T210); + auto T218 = mul(T211, T52); + auto T219 = add(T213, T212); + auto T220 = mul(T214, T54); + auto T221 = add(T216, T215); + auto T222 = sum(T217, {0, 2}); + auto T223 = add(T219, T218); + auto T224 = add(T221, T220); + auto T225 = castOp(DataType::BFloat16, T222); + auto T226 = castOp(DataType::BFloat16, T223); + auto T227 = castOp(DataType::BFloat16, T224); + auto T233 = broadcast(T225, {true, false, false, false}); + auto T234 = permute(T226, {0, 2, 1, 3}); + auto T235 = permute(T227, {0, 2, 1, 3}); + auto T236 = permute(T233, {0, 2, 1, 3}); + auto T241 = reshape( + T234, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(seq_len), + IrBuilder::create(num_attention_heads * head_dim)}); + auto T246 = reshape( + T235, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(seq_len), + IrBuilder::create(num_key_value_heads * head_dim)}); + auto T251 = reshape( + T236, + std::vector{ + IrBuilder::create(batch_size), + IrBuilder::create(seq_len), + IrBuilder::create(num_key_value_heads * head_dim)}); + fusion.addOutput(T251); + fusion.addOutput(T246); + fusion.addOutput(T241); + + auto options_fp32 = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_bf16 = + at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape0, options_bf16); + auto t1 = at::randn(shape1, options_fp32); + auto t2 = at::randn(shape2, options_bf16); + auto t3 = at::randn(shape3, options_bf16); + std::vector inputs({t0, t1, t2, t3}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); +} + +using LitgptRopeTest = RopeTest; + +INSTANTIATE_TEST_SUITE_P( + , + LitgptRopeTest, + testing::Values( + RopeConfig{32, 128, 32, 128, 2, 4096}, // Llama2-7b-hf + RopeConfig{32, 128, 8, 128, 2, 8192}, // Llama3-8B + RopeConfig{4, 16, 4, 16, 2, 8}, // Small test config + RopeConfig{8, 16, 4, 16, 2, 8} // Small test config + ), + [](const testing::TestParamInfo& info) { + return info.param.toCompactString(); + }); + +TEST_P(LitgptRopeTest, Fwd) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + const RopeConfig config = GetParam(); + config.verify(); + + int64_t q_per_kv = config.n_head / config.n_query_groups; + int64_t total_qkv = q_per_kv + 2; + + int64_t rotation_num_splits = 2; + + std::vector shape_before_reshape{ + config.batches, + config.seq_length, + config.head_size * (config.n_head + 2 * config.n_query_groups)}; + std::vector shape_before_permutation{ + config.batches, + config.seq_length, + config.n_query_groups, + total_qkv, + config.head_size}; + std::vector shape_after_permutation{ + config.batches, + config.n_query_groups, + total_qkv, + config.seq_length, + config.head_size}; + std::vector shape_after_reshape{ + config.batches, + config.n_query_groups * total_qkv, + config.seq_length, + config.head_size}; + + const auto& input_shape = shape_before_reshape; + + // qkv after 
permutation + auto tv0 = makeContigConcreteTensor(input_shape, DataType::BFloat16); + fusion.addInput(tv0); + + // cos + auto tv1 = makeContigConcreteTensor( + {config.seq_length, config.rope_n_elem}, DataType::BFloat16); + fusion.addInput(tv1); + auto cos = tv1; + + // sin + auto tv2 = makeContigConcreteTensor( + {config.seq_length, config.rope_n_elem}, DataType::BFloat16); + fusion.addInput(tv2); + auto sin = tv2; + + auto zero = fusion.zeroVal(); + + auto qkv = reshape(tv0, shape_before_reshape, shape_before_permutation); + qkv = permute(qkv, {0, 2, 3, 1, 4}); + + std::vector slice_default_arg; + slice_default_arg.reserve(shape_after_permutation.size()); + for (const auto s : shape_after_permutation) { + slice_default_arg.push_back(Slice{zero, IrBuilder::create(s)}); + } + + int64_t qkv_slice_dim = 2; + + auto slice_arg_q = slice_default_arg; + slice_arg_q[qkv_slice_dim].stop = IrBuilder::create(total_qkv - 2); + + auto slice_arg_k = slice_default_arg; + slice_arg_k[qkv_slice_dim].start = IrBuilder::create(q_per_kv); + slice_arg_k[qkv_slice_dim].stop = IrBuilder::create(total_qkv - 1); + + auto apply_rope = [&](TensorView* x, + bool is_q, + std::vector slice_arg) -> TensorView* { + auto x_slice = slice(x, slice_arg); + + std::vector cur_shape = shape_after_permutation; + cur_shape[qkv_slice_dim] = is_q ? q_per_kv : 1; + std::vector new_shape{ + cur_shape[0], + config.n_query_groups * (is_q ? q_per_kv : 1), + config.seq_length, + config.rope_n_elem}; + x_slice = reshape(x_slice, cur_shape, new_shape); + + // x1 + std::vector x1_slice_arg; + x1_slice_arg.reserve(new_shape.size()); + for (const auto s : new_shape) { + x1_slice_arg.push_back(Slice{zero, IrBuilder::create(s)}); + } + + x1_slice_arg.back().stop = + IrBuilder::create(config.rope_n_elem / rotation_num_splits); + auto x1 = slice(x_slice, x1_slice_arg); + + // x2 + auto x2_slice_arg = x1_slice_arg; + x2_slice_arg.back().start = + IrBuilder::create(config.rope_n_elem / rotation_num_splits); + x2_slice_arg.back().stop = IrBuilder::create(config.rope_n_elem); + auto x2 = slice(x_slice, x2_slice_arg); + + auto rotated = cat({x2, x1}, -1); + + std::vector bcast_flags(new_shape.size(), false); + for (auto it = bcast_flags.begin(); + it != bcast_flags.begin() + (int64_t)bcast_flags.size() - 2; + ++it) { + *it = true; + } + auto cos_broadcast = broadcast(cos, bcast_flags); + auto sin_broadcast = broadcast(sin, bcast_flags); + + TensorView* out = + add(mul(x_slice, cos_broadcast), mul(rotated, sin_broadcast)); + out = castOp(DataType::BFloat16, out); + return out; + }; + + auto q_out = apply_rope(qkv, true, slice_arg_q); + fusion.addOutput(q_out); + + auto k_out = apply_rope(qkv, false, slice_arg_k); + fusion.addOutput(k_out); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(input_shape, options); + auto t1 = at::randn({config.seq_length, config.rope_n_elem}, options); + auto t2 = at::randn({config.seq_length, config.rope_n_elem}, options); + std::vector inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp index b893c5c29ad..8bcf3b13d60 100644 --- a/tests/cpp/test_segmentation.cpp +++ b/tests/cpp/test_segmentation.cpp @@ -552,51 +552,6 @@ TEST_F(SegmentationTest, ForceBf16NotAllCast) { } } -// Test that a segment with a slice 
does not introduce a cast -// See https://github.com/NVIDIA/Fuser/pull/1936 -TEST_F(SegmentationTest, SliceSegmentCasts) { - EnableOptionsGuard opt_guard; - EnableOptionsGuard::getCurOptions().set(EnableOption::IoToLowerPrecision); - - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1, DataType::Half); - - fusion->addInput(tv0); - - // Group 1 - auto tv1 = mul(tv0, tv0); - // Group 2 - auto tv2 = slice(tv1, {0}, {3}, {1}); - auto tv3 = add(tv2, tv2); - - fusion->addOutput(tv3); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto in0 = at::randn({5}, options); - auto outputs = executor_cache.runFusionWithInputs({in0}); - - SegmentedFusion* segmented_fusion = - executor_cache.getMostRecentKernelRuntime()->fusionSegments(); - - ASSERT_EQ(segmented_fusion->edges().size(), 1); - - SegmentedEdge* slice_edge = segmented_fusion->edges().at(0); - - // Expect edge to be half-precision - // TODO: Change this rhs to DataType::Half once we have addressed - // https://github.com/NVIDIA/Fuser/issues/1902 - EXPECT_EQ(slice_edge->val->getDataType(), DataType::Float); - - // There should be no cast before the slice - EXPECT_TRUE(slice_edge->val->uses().at(0)->isA()); - - testValidate(executor_cache.fusion(), outputs, {in0}, __LINE__, __FILE__); -} - TEST_F(SegmentationTest, codeGenSupportedMergeIssue1970) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index 2eabde38b3b..3d95ad7d3c4 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -230,6 +230,46 @@ TEST_F(TensorFactoryTest, StandaloneIota) { } } +TEST_F(TensorFactoryTest, SimpleTriu) { + std::vector> input_sizes_2d = { + {64, 64}, {4, 16}, {16, 4}}; + std::vector> input_sizes_3d = {{16, 8, 32}}; + auto offsets = {0, 1, 2, -1, -2, 200, -200}; + + for (auto in : {input_sizes_2d, input_sizes_3d}) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half); + auto input_offset = IrBuilder::create(DataType::Int); + auto out = triu(tv_to_triu_on, input_offset); + + fusion->addInput(tv_to_triu_on); + fusion->addInput(input_offset); + fusion->addOutput(out); + + FusionExecutorCache executor_cache(std::move(fusion)); + + for (auto input_size : in) { + for (auto offset : offsets) { + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto in_tensor = at::randn(input_size, options); + + auto cg_outputs = + executor_cache.runFusionWithInputs({in_tensor, offset}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {in_tensor, offset}, + {at::triu(in_tensor, offset)}, + __LINE__, + __FILE__); + } + } + } +} + TEST_F(TensorFactoryTest, StandaloneARange) { auto starts_ends = {-1., 0., 10.3, 1024. * 256}; auto steps = {-1.5, 1., 2.}; diff --git a/tests/cpp/test_translate_mma.cpp b/tests/cpp/test_translate_mma.cpp index f290d30b576..ab6eb658cf6 100644 --- a/tests/cpp/test_translate_mma.cpp +++ b/tests/cpp/test_translate_mma.cpp @@ -43,10 +43,10 @@ namespace nvfuser { class CombineMulSumAsMmaTest : public NVFuserTest { void SetUp() override { // These test are enable for Turing and newer. Temporarily - // we are skipping Hopper since the matmul for it is under development. + // we are skipping Blackwell since the matmul for it is under development. 
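+ // (In other words, the guard below limits these tests to GPUs whose compute
+ // capability falls in the 8.0-10.0 range, i.e. Ampere through Hopper; Blackwell is skipped.)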
auto lower_major = 8; auto lower_minor = 0; - auto upper_major = 9; + auto upper_major = 10; auto upper_minor = 0; if (cudaArchGuardShouldSkip( lower_major, lower_minor, upper_major, upper_minor)) { @@ -55,8 +55,14 @@ class CombineMulSumAsMmaTest : public NVFuserTest { << lower_minor << "and " << upper_major << "." << upper_minor << " to run.\n"; } + + pre_hopper = at::cuda::getCurrentDeviceProperties()->major < 9; + NVFuserTest::SetUp(); } + + protected: + bool pre_hopper; }; class CombineMulSumAsMmaTestWithLayout @@ -66,24 +72,30 @@ class CombineMulSumAsMmaTestWithLayout MmaLayout layout; void SetUp() override { layout = GetParam(); - // These test are enable for Turing and newer. Temporarily - // we are skipping Hopper since the matmul for it is under development. + // These tests are enabled for Turing and newer. + // We are temporarily skipping Blackwell since the matmul for it is under development. auto lower_major = 8; auto lower_minor = 0; - auto upper_major = 9; + auto upper_major = 10; auto upper_minor = 0; if (cudaArchGuardShouldSkip( lower_major, lower_minor, upper_major, upper_minor)) { - GTEST_SKIP() << "CombineMulSumAsMmaTest skipped " + GTEST_SKIP() << "CombineMulSumAsMmaTestWithLayout skipped " << "Requires GPU capability between " << lower_major << "." << lower_minor << "and " << upper_major << "." << upper_minor << " to run.\n"; } + pre_hopper = at::cuda::getCurrentDeviceProperties()->major < 9; NVFuserTest::SetUp(); } + + bool pre_hopper; }; -void performSubstitution(Fusion* fusion, bool should_not_find = false) { +void performSubstitution( + Fusion* fusion, + bool avoid_intermediates, + bool should_not_find = false) { EXPECT_TRUE(ir_utils::getOpsOfType(fusion).empty()); std::vector patterns = @@ -96,14 +108,14 @@ void performSubstitution(Fusion* fusion, bool should_not_find = false) { ASSERT_FALSE(patterns.empty()); EXPECT_EQ(patterns.size(), 1); - patterns.front().translateToMmaOp(); + patterns.front().translateToMmaOp(avoid_intermediates); ASSERT_FALSE(ir_utils::getOpsOfType(fusion).empty()); } // Test checks to see that the combiner can correctly replace // the mul-sum pair with a mma op. -TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Pass) { +TEST_P(CombineMulSumAsMmaTestWithLayout, MulSumToMatmul_Pass) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -119,13 +131,13 @@ TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Pass) { fusion.addOutput(tv3); - performSubstitution(&fusion); + performSubstitution(&fusion, /*avoid_intermediates=*/!pre_hopper); } // This test checks that the pattern matcher does not incorrectly identify // this mul-sum pair, as the mul is not fed by broadcasts ops; i.e. it is // not a matmul. -TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { +TEST_F(CombineMulSumAsMmaTest, MulSumToMatmul_Fail1) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(3, DataType::Half); @@ -138,11 +150,15 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - performSubstitution(&fusion, /*should_not_find=*/true); + performSubstitution( + &fusion, /*avoid_intermediates=*/!pre_hopper, /*should_not_find=*/true); } // This fusion has Broadcast batch axes in each operand. -TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { +TEST_F(CombineMulSumAsMmaTest, MulSumToMatmul_MultipleBroadcasts) { + // This test explicitly broadcasts and transposes, so we cannot avoid + // intermediates on Hopper (yet). 
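+ // Hence the arch range guard below, which keeps this test on pre-Hopper
+ // (Turing/Ampere/Ada) devices.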
+ NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); // Assumes layout is kAllSupportedMmaLayout::NT; std::unique_ptr fusion_ptr = std::make_unique(); Fusion* fusion = fusion_ptr.get(); @@ -170,7 +186,8 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { auto tv3 = sum(tv2, {-1}); fusion->addOutput(tv3); - performSubstitution(fusion, /*should_not_find=*/false); + performSubstitution( + fusion, /*avoid_intermediates=*/!pre_hopper, /*should_not_find=*/false); // We test running this fusion also to verify that the broadcast batch // dimension does not cause unforeseen issues @@ -192,6 +209,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { // pair with a mma op, we are able to schedule it as we did with // a fusion that had a mma op to begin with. TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Schedule) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; @@ -209,7 +227,7 @@ TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Schedule) { fusion.addOutput(tv2); - performSubstitution(&fusion); + performSubstitution(&fusion, /*avoid_intermediates=*/!pre_hopper); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); @@ -239,6 +257,7 @@ TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Schedule) { } TEST_P(CombineMulSumAsMmaTestWithLayout, UseMatmulScheduler) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; auto fusion = std::make_unique(); diff --git a/tests/cpp/test_vectorization_analysis.cpp b/tests/cpp/test_vectorization_analysis.cpp new file mode 100644 index 00000000000..6bfee4de52d --- /dev/null +++ b/tests/cpp/test_vectorization_analysis.cpp @@ -0,0 +1,267 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// #include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { + +namespace { + +void checkMappedVal( + const std::unordered_map& map, + TensorView* tv_target, + int64_t val) { + auto iter = map.find(tv_target); + EXPECT_TRUE(iter != map.end()); + if (iter != map.end()) { + EXPECT_EQ(iter->second->evaluate(), val); + } +} + +} // namespace + +using VectorizationAnalysisTest = NVFuserTest; + +// Simple pad test +TEST_F( + VectorizationAnalysisTest, + ContigInnerDimsMapperResizeFastestDimensionP2C) { + Fusion fusion; + FusionGuard fg(&fusion); + std::vector> expection_list; + + auto tv0 = makeContigConcreteTensor({4, 8, 16}); + fusion.addInput(tv0); + + // positive resize (+2, +2) + auto inner_pos = + pad(tv0, {IrBuilder::create(2L), IrBuilder::create(2L)}); + expection_list.emplace_back(std::make_pair(inner_pos, 2)); + fusion.addOutput(inner_pos); + + // positive uneven resize (+4, +2) + auto inner_pos_uneven = + pad(tv0, {IrBuilder::create(4L), IrBuilder::create(2L)}); + expection_list.emplace_back(std::make_pair(inner_pos_uneven, 2)); + fusion.addOutput(inner_pos_uneven); + + // positive large resize (+32, +32) + auto inner_pos_large = + pad(tv0, {IrBuilder::create(32L), IrBuilder::create(32L)}); + // projected extent is 16 + expection_list.emplace_back(std::make_pair(inner_pos_large, 16)); + fusion.addOutput(inner_pos_large); + + // negative resize (-2, -2) + auto inner_neg = + pad(tv0, {IrBuilder::create(-2L), IrBuilder::create(-2L)}); + expection_list.emplace_back(std::make_pair(inner_neg, 2)); + fusion.addOutput(inner_neg); + + // negative uneven resize (-2, -4) + auto inner_neg_uneven = + pad(tv0, {IrBuilder::create(-2L), IrBuilder::create(-4L)}); + expection_list.emplace_back(std::make_pair(inner_neg_uneven, 2)); + fusion.addOutput(inner_neg_uneven); + + // negative large resize to zero (-8, -8) + auto inner_neg_large = + pad(tv0, {IrBuilder::create(-8L), IrBuilder::create(-8L)}); + // output id with extent 0 cannot be vectorized + expection_list.emplace_back(std::make_pair(inner_neg_large, 0)); + fusion.addOutput(inner_neg_large); + + // uneven resize (-2, 4) + auto inner_uneven = + pad(tv0, {IrBuilder::create(-2L), IrBuilder::create(4L)}); + expection_list.emplace_back(std::make_pair(inner_uneven, 2)); + fusion.addOutput(inner_uneven); + + // one side resize (0, 4) + auto inner_one_size = + pad(tv0, {IrBuilder::create(0L), IrBuilder::create(4L)}); + // resize extent of 0 wouldn't affect vectorization factor + expection_list.emplace_back(std::make_pair(inner_one_size, 4)); + fusion.addOutput(inner_one_size); + + std::unordered_map projected_extent_map = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv0, tv0->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + + for (const auto& [tv, val] : expection_list) { + checkMappedVal(projected_extent_map, tv, val); + } +} + +// Simple pad test +TEST_F( + VectorizationAnalysisTest, + ContigInnerDimsMapperResizeFastestDimensionC2P) { + Fusion fusion; + FusionGuard fg(&fusion); + std::vector> expection_list; + + auto tv0 = makeContigConcreteTensor({4, 8, 8}); + fusion.addInput(tv0); + // positive resize (+24, +24) + auto tv1 = + pad(tv0, {IrBuilder::create(24L), IrBuilder::create(24L)}); + fusion.addOutput(tv1); + + // negative resize to zero (-4, -4) + auto tv2 = + pad(tv0, {IrBuilder::create(-4), IrBuilder::create(-4L)}); + fusion.addOutput(tv2); + + std::unordered_map projected_extent_map_from_tv1 
= + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv1, tv1->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + checkMappedVal(projected_extent_map_from_tv1, tv0, 8); + checkMappedVal(projected_extent_map_from_tv1, tv2, 0); + + // because tv2's fastest dimension is resized to 0 + std::unordered_map projected_extent_map_from_tv2 = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv2, tv2->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + checkMappedVal(projected_extent_map_from_tv2, tv0, 0); + checkMappedVal(projected_extent_map_from_tv2, tv1, 0); +} + +TEST_F(VectorizationAnalysisTest, ContigInnerDimsMapperResizeMiddleDimension) { + Fusion fusion; + FusionGuard fg(&fusion); + std::vector> expection_list; + + auto tv0 = makeContigConcreteTensor({4, 8, 16}); + fusion.addInput(tv0); + + // positive resize (+2, +2) + auto middle_pos = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(2L), + IrBuilder::create(2L)}); + expection_list.emplace_back(std::make_pair(middle_pos, 2 * 16)); + fusion.addOutput(middle_pos); + + // negative resize (-2, -2) + auto middle_neg = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(-2L), + IrBuilder::create(-2L)}); + expection_list.emplace_back(std::make_pair(middle_neg, 2 * 16)); + fusion.addOutput(middle_neg); + + std::unordered_map projected_extent_map = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv0, tv0->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + for (const auto& [tv, val] : expection_list) { + checkMappedVal(projected_extent_map, tv, val); + } +} + +TEST_F( + VectorizationAnalysisTest, + ContigInnerDimsMapperResizeMultipleDimension) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({4, 8, 32}); + fusion.addInput(tv0); + + // the inner most dimension of resize would participate in vectorization + auto tv1 = + pad(tv0, + {IrBuilder::create(8L), + IrBuilder::create(8L), + IrBuilder::create(4L), + IrBuilder::create(4L)}); + fusion.addOutput(tv1); + + std::unordered_map projected_extent_map_from_producer = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv0, tv0->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + checkMappedVal(projected_extent_map_from_producer, tv1, 8); + + std::unordered_map projected_extent_map_from_consumer = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv1, tv1->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + checkMappedVal(projected_extent_map_from_consumer, tv0, 8); +} + +TEST_F(VectorizationAnalysisTest, ContigInnerDimsMapperResizeStacked) { + Fusion fusion; + FusionGuard fg(&fusion); + std::vector> expection_list; + + auto tv0 = makeContigConcreteTensor({4, 8, 36}); + fusion.addInput(tv0); + // resize on different dimension + auto tv1 = + pad(tv0, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(-4L), + IrBuilder::create(-4L)}); + auto tv2 = + pad(tv1, + {IrBuilder::create(0L), + IrBuilder::create(0L), + IrBuilder::create(-2L), + IrBuilder::create(-2L)}); + // only the inner most resize is included in vectorization analysis + expection_list.emplace_back(std::make_pair(tv2, 2 * 36)); + fusion.addOutput(tv2); + + // resize on the same dimension, squeeze size to zero + auto tv3 = + pad(tv0, {IrBuilder::create(-9L), IrBuilder::create(-9L)}); + auto tv4 = + pad(tv3, {IrBuilder::create(-9L), IrBuilder::create(-9L)}); + // output id with extent 0 
cannot be vectorized + expection_list.emplace_back(std::make_pair(tv4, 0)); + fusion.addOutput(tv4); + + // resize on the same dimension + auto tv5 = + pad(tv0, {IrBuilder::create(-6L), IrBuilder::create(-6L)}); + auto tv6 = pad(tv5, {IrBuilder::create(9L), IrBuilder::create(9L)}); + // two resize operation would stack + expection_list.emplace_back(std::make_pair(tv6, 3)); + fusion.addOutput(tv6); + + std::unordered_map projected_extent_map = + vectorize_helper::ContiguousInnerDimensionsMapper::map( + tv0, tv0->getLogicalDomain()) + .getTvToContigMergeOfInnerSizeMap(); + for (const auto& [tv, val] : expection_list) { + checkMappedVal(projected_extent_map, tv, val); + } +} + +} // namespace nvfuser diff --git a/tests/python/test_matmul.py b/tests/python/test_matmul.py index 14176154e1f..ed1b12c2063 100644 --- a/tests/python/test_matmul.py +++ b/tests/python/test_matmul.py @@ -4,7 +4,7 @@ # Owner(s): ["module: nvfuser"] import torch -from utils import NVFuserTest, is_pre_volta +from utils import NVFuserTest, is_pre_volta, verify_stride_order from nvfuser import FusionDefinition, DataType import pytest from functools import partial @@ -201,3 +201,29 @@ def fusion_func(fd: FusionDefinition) -> None: ] outputs, _ = self.exec_nvfuser(fusion_func, inputs) assert outputs[0].ndim == 3 + + def test_matmul_stride(self): + n, h, l, s, e = 4, 8, 16, 16, 8 + inputs = [ + torch.randn( + n, h, l, e, device="cuda", dtype=torch.float16, requires_grad=True + ), + torch.randn( + n, h, s, e, device="cuda", dtype=torch.float16, requires_grad=True + ), + ] + for perm in itertools.permutations(range(4), 4): + + def fusion_func(fd: FusionDefinition) -> None: + q = fd.from_pytorch(inputs[0]) + k = fd.from_pytorch(inputs[1]) + k_t = fd.ops.permute(k, [0, 1, 3, 2]) + out = fd.ops.matmul(q, k_t) + fd.add_output(out, stride_order=perm) + + with FusionDefinition() as fd: + fusion_func(fd) + nvf_out = fd.execute(inputs) + eager_out = torch.matmul(inputs[0], torch.transpose(inputs[1], -2, -1)) + verify_stride_order(nvf_out[0].stride(), perm) + torch.testing.assert_close(nvf_out[0], eager_out) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index e46729c594c..0b74fddeae6 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -3331,7 +3331,10 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T27) fd.add_output(T28) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + # TODO: Support segmentation. See #3594. + nvf_out, _ = self.exec_nvfuser( + fusion_func, inputs, is_clonable=True, supports_segmentation=False + ) t12 = inputs[1] * inputs[-2] t13 = torch.permute(t12, [0, 1, 3, 2]) @@ -4383,7 +4386,10 @@ def fusion_func(fd: FusionDefinition): fd.add_output(T18) fd.add_output(T16) - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True) + # TODO: Support segmentation. See #3594. + nvf_out, _ = self.exec_nvfuser( + fusion_func, inputs, is_clonable=True, supports_segmentation=False + ) def test_returning_aliased_outputs(self): inputs = [torch.randn((1, 2, 3, 4), dtype=torch.float32, device="cuda:0")] @@ -4714,7 +4720,7 @@ def fusion_func(fd: FusionDefinition) -> None: # extents range from [-1, -6]. 
self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)]) - def test_issue_3292(self): + def test_issue3292(self): inputs = [ torch.testing.make_tensor( (5, 5, 576), dtype=torch.float32, device="cuda:0" @@ -4827,3 +4833,32 @@ def fusion_func(fd: FusionDefinition, inps) -> None: # Serializing error test cases corrupts the serialized binary causing subsequent tests to fail. # Reset the fusion cache to avoid this. FusionCache.reset() + + def test_issue1279(self): + inputs = [ + torch.randn(2, 1, 2, dtype=torch.float16, device="cuda:0"), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, 1, -1], + contiguity=[True, None, True], + dtype=DataType.Half, + is_cpu=False, + ) + T4 = fd.ops.cast(T0, dtype=DataType.Float) + T5, T6 = fd.ops.var_mean(T4, dims=[1], correction=1, keepdim=False) + T7 = fd.ops.cast(T5, dtype=DataType.Half) + T8 = fd.ops.cast(T6, dtype=DataType.Half) + fd.add_output(T7) + fd.add_output(T8) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + a = inputs[0].type(torch.float32) + b, c = torch.var_mean(a, dim=1) + d = b.type(torch.float16) + e = c.type(torch.float16) + + self.assertEqual(nvf_out[0], d) + self.assertEqual(nvf_out[1], e) diff --git a/tests/python/test_sdpa.py b/tests/python/test_sdpa.py index 95da686d060..87cf02f0c52 100644 --- a/tests/python/test_sdpa.py +++ b/tests/python/test_sdpa.py @@ -3,13 +3,14 @@ # SPDX-License-Identifier: BSD-3-Clause # Owner(s): ["module: nvfuser"] -import torch -from utils import NVFuserTest, is_pre_ampere -from nvfuser import FusionDefinition, DataType, FusionCache -import pytest import itertools -from functools import partial +import math +import pytest +import torch import torch.nn.functional as F +from functools import partial +from nvfuser import FusionDefinition, DataType, FusionCache +from utils import NVFuserTest, is_pre_ampere @pytest.mark.skipif( @@ -17,6 +18,48 @@ reason="Flash Attention is only supported on Ampere and newer devices.", ) class TestSdpa(NVFuserTest): + def test_softmax_logsumexp(self): + def fusion_func(fd: FusionDefinition) -> None: + q = fd.define_tensor( + shape=[-1, -1, -1, -1], + dtype=DataType.BFloat16, + ) + k = fd.define_tensor( + shape=[-1, -1, -1, -1], + dtype=DataType.BFloat16, + ) + v = fd.define_tensor( + shape=[-1, -1, -1, -1], + dtype=DataType.BFloat16, + ) + ( + _, + lse, + *_, + ) = fd.ops.sdpfa_fwd(q, k, v, dropout_p=None, is_causal=None, scale=None) + fd.add_output(lse) + + n, h, l, s, e = 1, 1, 4, 4, 2 + inputs = [ + torch.ones((n, h, l, e), dtype=torch.bfloat16, device="cuda"), + torch.ones((n, h, s, e), dtype=torch.bfloat16, device="cuda"), + torch.ones((n, h, s, e), dtype=torch.bfloat16, device="cuda"), + ] + + from torch.nn.attention import SDPBackend, sdpa_kernel + + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + nvf_out, _ = self.exec_nvfuser( + fusion_func, + inputs, + ) + # Ignoring size-1 dimensions, `q @ k^T / sqrt(e)` generates a `l`x`s` + # matrix full of `sqrt(e)`s. Therefore, the logsumexp of each row is + # expected to be log(exp(sqrt(e)) * s) = log(s) + sqrt(e). 
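+ # With the sizes used here (s = 4, e = 2), that is
+ # log(4) + sqrt(2) ~= 1.3863 + 1.4142 = 2.8005 for every row.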
+ torch.testing.assert_close( + nvf_out[0].cpu(), torch.full((n, h, l), math.log(s) + e**0.5) + ) + def test_sdpa_fwd(self): def fusion_func( fd: FusionDefinition, has_dropout: bool, has_causal: bool, has_scale: bool @@ -46,9 +89,7 @@ def fusion_func( is_causal = fd.define_scalar(value=None, dtype=DataType.Bool) if has_scale: scale = fd.define_scalar(value=None, dtype=DataType.Double) - attn, *intermediate_results = fd.ops.sdpfa_fwd( - q, k, v, dropout_p, is_causal, scale - ) + attn, *_ = fd.ops.sdpfa_fwd(q, k, v, dropout_p, is_causal, scale) fd.add_output(attn) N, H, L, S, E = 4, 8, 16, 16, 8 diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py index 00eb4b9eeb7..5c71633ecf1 100644 --- a/tests/python/test_transformer_engine.py +++ b/tests/python/test_transformer_engine.py @@ -59,7 +59,22 @@ def setup_process_group(mpi_test) -> None: [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL], ids=["tp", "sp"], ) -def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism): +@pytest.mark.parametrize( + "overlap", + [False, True], + ids=["nonoverlap", "overlap"], +) +def test_transformer_layer( + setup_process_group, + monkeypatch, + benchmark, + compute_type: ComputeType, + parallelism: Parallelism, + overlap: bool, +): + if overlap and parallelism == Parallelism.TENSOR_PARALLEL: + pytest.skip("Tensor parallelism doesn't support overlapping") + # Hyperparameters for GPT-3 hidden_size = 12288 num_heads = 96 @@ -77,11 +92,13 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type, paralle hidden_size, ffn_hidden_size, num_heads, - # https://github.com/NVIDIA/TransformerEngine/issues/1350: the - # benchmark fails to execute on H100 with the default format (SBHD). + # According to https://github.com/NVIDIA/TransformerEngine/issues/1350, + # `attn_input_format` has to match the format of `transformer_layer`'s + # input. attn_input_format="bshd", set_parallel_mode=True, sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL), + ub_tp_comm_overlap=overlap, tp_group=dist.group.WORLD, ) transformer_layer.to(dtype).to("cuda") @@ -97,6 +114,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type, paralle batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda" ) + if overlap: + # Similar to https://github.com/NVIDIA/TransformerEngine/blob/e7bfc0c547d63332e4f8d65e606dc69f4c22ffbe/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py#L27-L29 + monkeypatch.setenv("CUDA_DEVICE_MAX_CONNECTIONS", "1") + if not te.cpp_extensions.device_supports_multicast(): + monkeypatch.setenv("UB_SKIPMC", "1") + + te.module.base.initialize_ub( + # Instructed by https://github.com/NVIDIA/TransformerEngine/blob/e7bfc0c547d63332e4f8d65e606dc69f4c22ffbe/transformer_engine/pytorch/module/base.py#L96-L99 + [batch_size * sequence_length, hidden_size], + size, + dtype=dtype, + bootstrap_backend="nccl", + ) + match compute_type: case ComputeType.FORWARD: @@ -155,3 +186,6 @@ def benchmark_fn(y, dy, profile): setup=partial(setup_fn, True), rounds=5, ) + + if overlap: + te.module.base.destroy_ub() diff --git a/version.txt b/version.txt index 521eb3d6e6c..ac16615536e 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.23 +0.2.24
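For reference, the RoPE tests added above (MistralRopeTest Fwd/Bwd and LitgptRopeTest Fwd) exercise the usual rotate-half rotary embedding: the head dimension is split in two, the second half is negated and swapped in front of the first, and the result is combined with the original tensor through elementwise cos/sin weights. A minimal PyTorch sketch of that computation follows; the helper names are illustrative only (not nvFuser or test APIs), and LitgptRopeTest builds the rotated half without the negation.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last (head) dimension into two halves and swap them,
    # negating the second half: [x1, x2] -> [-x2, x1].
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; cos/sin: [seq, head_dim],
    # broadcast over the batch and head dimensions.
    return x * cos + rotate_half(x) * sin

if __name__ == "__main__":
    batch, heads, seq, head_dim = 1, 4, 16, 8
    x = torch.randn(batch, heads, seq, head_dim)
    angles = torch.randn(seq, head_dim // 2)
    cos = torch.cat((angles, angles), dim=-1).cos()
    sin = torch.cat((angles, angles), dim=-1).sin()
    print(apply_rope(x, cos, sin).shape)  # torch.Size([1, 4, 16, 8])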