diff --git a/CMakeLists.txt b/CMakeLists.txt index 4dda7636678..f26d7ada7f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,8 +231,10 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp ${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/resize.cpp ${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp ${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp + ${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp ${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 794de6f9d92..06fa56bb085 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3740,10 +3740,10 @@ void TensorDomain::setAllocationDomain( std::vector TensorDomain::allIDs() const { std::array*, 6> all_domains = { + &loop_domain_, &logical_domain_, &root_domain_, &initial_loop_domain_, - &loop_domain_, &allocation_domain_, &additional_ids_}; VectorOfUniqueEntries discovered_ids; diff --git a/csrc/options.cpp b/csrc/options.cpp index 639e0c57622..f53ef79893d 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -154,15 +154,16 @@ const std::unordered_map& getEnableOptions() { {"fuse_matmul", EnableOption::FuseMatmul}, {"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls}, {"id_model", EnableOption::IdModel}, + {"io_to_lower_precision", EnableOption::IoToLowerPrecision}, {"kernel_db", EnableOption::KernelDb}, + {"kernel_debug", EnableOption::KernelDebug}, + {"kernel_lineinfo", EnableOption::KernelLineInfo}, {"kernel_profile", EnableOption::KernelProfile}, {"memory_promotion", EnableOption::MemoryPromotion}, {"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory}, + {"resize_scheduler", EnableOption::ResizeScheduler}, {"static_fusion_count", EnableOption::StaticFusionCount}, {"warn_register_spill", EnableOption::WarnRegisterSpill}, - {"io_to_lower_precision", EnableOption::IoToLowerPrecision}, - {"kernel_debug", EnableOption::KernelDebug}, - {"kernel_lineinfo", EnableOption::KernelLineInfo}, }; return available_options; } diff --git a/csrc/options.h b/csrc/options.h index 8d69719897c..d10e739af60 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -93,17 +93,18 @@ enum class EnableOption { FuseMatmul, //! Enable automatic fusion of matmul and linear ops FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel IdModel, //! Enable IdModel - KernelDb, //! Enable Kernel Database - KernelProfile, //! Enable intra-kernel performance profiling - MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops - StaticFusionCount, //! Enable using single static count in kernel name - ReuseZeroedMemory, //! Re-use zeroed memory used for grid synchronization - WarnRegisterSpill, //! Enable warnings of register spill IoToLowerPrecision, //! Enable castInputOutputToLowerPrecision. #1889 explains //! why we disabled it by default. + KernelDb, //! Enable Kernel Database KernelDebug, //! Enable debug mode in nvrtc KernelLineInfo, //! Embed line info to compiled kernel, and dump the full CUDA //! C++ code + KernelProfile, //! Enable intra-kernel performance profiling + MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops + ReuseZeroedMemory, //! Re-use zeroed memory used for grid synchronization + ResizeScheduler, //! Enable the resize scheduler + StaticFusionCount, //! 
Enable using single static count in kernel name + WarnRegisterSpill, //! Enable warnings of register spill EndOfOption //! Placeholder for counting the number of elements }; diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 613a084f36f..57e7624e772 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -823,7 +823,8 @@ void initNvFuserPythonBindings(PyObject* module) { .value("inner_outer_persistent", SchedulerType::InnerOuterPersistent) .value("outer_persistent", SchedulerType::OuterPersistent) .value("transpose", SchedulerType::Transpose) - .value("expr_eval", SchedulerType::ExprEval); + .value("expr_eval", SchedulerType::ExprEval) + .value("resize", SchedulerType::Resize); nvfuser.def("compute_contiguity", computeContiguity); nvfuser.def("compute_tensor_descriptor", computeTensorDescriptor); diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index f7ec9d4a97f..3436bcd70eb 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -54,7 +55,7 @@ enum class CompileTimeEntryType { //! stores the domain map of a fusion. class DomainMap { public: - using DataType = pointwise_utils::DomainMap; + using DataType = scheduler_tools::DomainMap; static const CompileTimeEntryType EntryType = CompileTimeEntryType::DOMAIN_MAP; }; @@ -63,7 +64,7 @@ class DomainMap { //! stores the domain map of a fusion, used by transpose scheduler. class TransposeDomainMap { public: - using DataType = pointwise_utils::DomainMap; + using DataType = scheduler_tools::DomainMap; static const CompileTimeEntryType EntryType = CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP; }; diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index bc7a0fb32c6..9b5068c04c8 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -29,37 +29,6 @@ namespace { // Unused at the moment, commenting for clang tidy constexpr int64_t kThreadX = 128; -class DomainMap : public pointwise_utils::DomainMap { - public: - using pointwise_utils::DomainMap::DomainMap; - - // The pointwise scheduler heuristics requires a minimum number of axes. - // The output reference tensor should respect this requirement. 
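Stepping back to the `resize_scheduler` entry added to csrc/options.cpp and the new `EnableOption::ResizeScheduler` above: the option is off by default, so callers have to opt in explicitly. A minimal sketch of how a test does that, mirroring the `EnableOptionsGuard` fixture that tests/cpp/test_resize.cpp adopts later in this diff (the fixture name here is illustrative):

```cpp
// Illustrative fixture: turn on the resize scheduler for the duration of a test.
// EnableOption and EnableOptionsGuard come from csrc/options.h; the guard member
// restores the previous option set when the test object is destroyed.
class ResizeEnabledTest : public NVFuserTest {
 protected:
  void SetUp() override {
    EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler);
    NVFuserTest::SetUp();
  }

 private:
  EnableOptionsGuard enable_options_guard_;
};
```

Outside of tests, the same string key (`resize_scheduler`) is what the option table above is matched against, typically through the `NVFUSER_ENABLE` environment variable.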
- TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const { - TensorView* result = nullptr; - int64_t max_dims = -1; - for (auto output_tv : - ir_utils::filterByType(fusion_->outputs())) { - if (isValidReference(output_tv) && - hasMinimumSize(output_tv, minimum_num_axes) && - !output_tv->isFusionInput()) { - int64_t n_dims = pointwise_utils::nRootDims(output_tv); - if (n_dims > max_dims) { - result = output_tv; - max_dims = n_dims; - } - } - } - return result; - } - - private: - bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { - NVF_ERROR(tv != nullptr); - return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); - } -}; - } // namespace std::unique_ptr getPointwiseHeuristics( @@ -79,14 +48,17 @@ std::unique_ptr getPointwiseHeuristics( auto domain_map_entry = HeuristicDataCacheEntry( - data_cache, - [fusion]() { return std::make_unique(fusion); }); - const auto& domain_map = dynamic_cast(domain_map_entry.get()); + data_cache, [fusion]() { + return std::make_unique( + fusion); + }); + const auto& domain_map = dynamic_cast( + domain_map_entry.get()); auto largest_out_entry = HeuristicDataCacheEntry( data_cache, [&domain_map]() { - std::vector data{domain_map.findReferenceTensorView()}; + std::vector data{domain_map.findReferenceTensor()}; return std::make_unique>(std::move(data)); }); TensorView* largest_out = largest_out_entry.get()[0]; @@ -432,19 +404,11 @@ std::unique_ptr getPointwiseHeuristics( return params; } -// Return reference tensor view. -TensorView* getReferenceTensorView(Fusion* fusion) { - FusionGuard fg(fusion); - DomainMap domain_map(fusion); - auto reference_tv = domain_map.findReferenceTensorView(); - return reference_tv; -} - //! Utility for canSchedule interface to check if this fusion has //! a fully broadcasted reference tensor, which is necessary for //! the pointwise scheduler. 
bool hasReferenceTensorView(Fusion* fusion) { - return getReferenceTensorView(fusion) != nullptr; + return pointwise_utils::getReferenceTensor(fusion) != nullptr; } bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { @@ -541,7 +505,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { return; } - TensorView* reference_tv = getReferenceTensorView(fusion); + TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion); NVF_ERROR( reference_tv != nullptr, diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index 2f4f119fc46..be100a9f54d 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -5,247 +5,35 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include -#include #include -#include - -#include namespace nvfuser { namespace pointwise_utils { -namespace { - -// Grab all exact set mappings from consumer to producer domains of -// indexed accesses, e.g., index_select -std::unordered_multimap< - std::shared_ptr>, - std::shared_ptr>> -getIndexedConsumerToProducerMap(Fusion* fusion, const ComputeAtMap& ca_map) { - std::unordered_multimap< - std::shared_ptr>, - std::shared_ptr>> - indexed_id_map; - - for (auto expr : fusion->exprs()) { - if (auto gather = dynamic_cast(expr)) { - auto p_id = gather->getIndexedID(); - auto c_id = gather->getConsumerOfIndexedID(); - indexed_id_map.emplace( - ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), - ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); - } else if (auto index_select = dynamic_cast(expr)) { - auto p_id = index_select->getIndexedID(); - auto c_id = index_select->getConsumerOfIndexedID(); - indexed_id_map.emplace( - ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), - ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); - } else { - // Note there's no consumer ID for select. This means we can't - // just propagate from consumers to indexed producers. It seems - // it's necessary to schedule producers and consumers separately - // in those cases. - continue; - } - } - - return indexed_id_map; -} - -// Check if a root ID of a fusion input tensor that is indirectly -// accessed by ops such as torchGather needs to be mapped with -// a reference tensor. Select has a similar effect as squeeze as the -// indexed domain is removed, so the domain does not need to be mapped -// as long as the tensor is a fusion input. Similarly, in index_select -// and torchGather, if the output domain is a broadcast, it does not -// need to be mapped if not resolved. -bool canIgnoreIndexedInputDomainID( - TensorView* input_tv, - IterDomain* root_id, - const ComputeAtMap& ca_map) { - NVF_ERROR(input_tv->isFusionInput()); - for (auto use : input_tv->uses()) { - if (auto select = dynamic_cast(use)) { - if (root_id != select->getIndexedID()) { - return false; - } - } else if (auto index_select = dynamic_cast(use)) { - // If the root_id is an indexed ID, and the consumer ID may be a - // broadcast. In that case, nothing needs to be mapped if the - // consumer broadcast is not resolved - if (root_id != index_select->getIndexedID() || - !ca_map - .getConcreteMappedID( - index_select->getConsumerOfIndexedID(), - IdMappingMode::PERMISSIVE) - ->isBroadcast()) { - return false; - } - } else if (auto gather = dynamic_cast(use)) { - // TODO: Remove this. Once slice is used for torchGather, this - // should not be necessary. 
For now, it is necessary to not - // break the existing torchGather tests - if (!gather->exactSizes()) { - continue; +TensorView* PointwiseDomainMap::findReferenceTensor( + int64_t minimum_num_axes) const { + TensorView* result = nullptr; + int64_t max_dims = -1; + for (auto output_tv : + ir_utils::filterByType(fusion_->outputs())) { + if (isValidReference(output_tv) && + hasMinimumSize(output_tv, minimum_num_axes) && + !output_tv->isFusionInput()) { + int64_t n_dims = pointwise_utils::nRootDims(output_tv); + if (n_dims > max_dims) { + result = output_tv; + max_dims = n_dims; } - // If the root_id is an indexed ID, and the consumer ID may be a - // broadcast. In that case, nothing needs to be mapped if the - // consumer broadcast is not resolved - if (root_id != gather->getIndexedID() || - !ca_map - .getConcreteMappedID( - gather->getConsumerOfIndexedID(), IdMappingMode::PERMISSIVE) - ->isBroadcast()) { - return false; - } - } else { - // If the input TV is used by any other ops - return false; } } - - return true; + return result; } -} // namespace - -DomainMap::DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { - tvs_with_rfactor_ = scheduler_utils::getTVsWithNonReductionRFactor(fusion); -} - -// Determine if all IterDomains in input are mapped to the given tensor -bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) - const { - // Get concrete IDs for input root or logical domain - std::unordered_set in_concrete_ids; - for (auto in_id : input_tv->getLogicalDomain()) { - if (canIgnoreIndexedInputDomainID(input_tv, in_id, ca_map_)) { - continue; - } - - // Permissive map is required for the transpose scheduler to support cases - // like T0[I0, b] + T1[b, I1] - auto concrete = - ca_map_.getConcreteMappedID(in_id, IdMappingMode::PERMISSIVE); - - if (!concrete->isBroadcast() && !in_id->isReduction()) { - in_concrete_ids.insert(concrete); - } - } - - // Erase all input concrete IDs mapped to the output domain - // Ignore unresolved broadcast dimensions - eraseifInputMappedThroughRootDomainAndIndexing( - in_concrete_ids, tv->getLogicalDomain()); - - return in_concrete_ids.empty(); -} - -// Reference domains must exactly match with the input domains. See -// also PR #661 -IterDomain* DomainMap::getMappedInputConcreteID( - const std::unordered_set& in_concrete_ids, - IterDomain* out_id) const { - auto in_concrete_id_iter = std::find_if( - in_concrete_ids.begin(), - in_concrete_ids.end(), - [&](IterDomain* in_concrete_id) { - return ca_map_.areMapped(in_concrete_id, out_id, IdMappingMode::EXACT); - }); - if (in_concrete_id_iter != in_concrete_ids.end()) { - return *in_concrete_id_iter; - } else { - return nullptr; - } -} - -// Erase input concrete ID if it is mapped to output ID -bool DomainMap::eraseIfMapped( - std::unordered_set& in_concrete_ids, - IterDomain* out_id) const { - auto mapped_input_conrete_id = - getMappedInputConcreteID(in_concrete_ids, out_id); - if (mapped_input_conrete_id != nullptr) { - in_concrete_ids.erase(mapped_input_conrete_id); - return true; - } else { - return false; - } -} - -void DomainMap::eraseifInputMappedThroughRootDomainAndIndexing( - std::unordered_set& in_ids, - const std::vector& ids) const { - // Use ComputeAtMap::getAllDisjointSetProducers to grab all producer - // IDs through rfactor exprs - VectorOfUniqueEntries>> - exact_sets; - std::for_each(ids.begin(), ids.end(), [&](IterDomain* id) { - exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); - }); - - // Traverse through indexed domains. 
- const auto indexed_id_multimap = - getIndexedConsumerToProducerMap(fusion_, ca_map_); - - VectorOfUniqueEntries>> - all_exact_sets_covered; - - // Back traverses through the exact map and indexed - // producer-consumer pairs - for (auto current_sets = exact_sets; !current_sets.empty();) { - auto producer_sets = ca_map_.getAllDisjointSetProducers(current_sets); - all_exact_sets_covered.pushBack(producer_sets); - - current_sets.clear(); - - // Further traversal if any of the new producer sets is a producer - // of indexed domains - for (const auto& producer_set : producer_sets) { - auto indexed_id_multimap_range = - indexed_id_multimap.equal_range(producer_set); - for (auto producer_of_producer_it = indexed_id_multimap_range.first; - producer_of_producer_it != indexed_id_multimap_range.second; - ++producer_of_producer_it) { - current_sets.pushBack(producer_of_producer_it->second); - } - } - } - - for (const auto& exact_set_ptr : all_exact_sets_covered) { - auto exact_concrete_id = ca_map_.getConcreteMappedID( - exact_set_ptr->front(), IdMappingMode::EXACT); - eraseIfMapped(in_ids, exact_concrete_id); - } -} - -// Find any id in domain that maps with target id -IterDomain* DomainMap::anyMapped( - const std::vector& domain, - IterDomain* target) const { - for (auto id : domain) { - if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) { - return id; - } - } - return nullptr; -} - -// Determine if output TensorView is a valid reference tensor for this fusion. -// The reference tensor must map to all the iterDomains in each input. -bool DomainMap::isValidReference(TensorView* tv) const { - for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { - if (input_tv->uses().empty()) { - continue; - } - // TODO: Same backward traversal from tv is done for all input - // tvs. Consider doing the analysis one for all inputs - if (!areAllInputIdsMappedTo(input_tv, tv)) { - return false; - } - } - return true; +TensorView* getReferenceTensor(Fusion* fusion) { + FusionGuard fg(fusion); + PointwiseDomainMap domain_map(fusion); + auto reference_tv = domain_map.findReferenceTensor(); + return reference_tv; } } // namespace pointwise_utils diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 56db0ee0806..cc9a43d5c0f 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -11,56 +11,12 @@ #include #include #include +#include #include namespace nvfuser { namespace pointwise_utils { -// DomainMap uses the ComputeAtMap to find a reference TensorView -// that maps to all IterDomains in the fusion. -class DomainMap { - public: - DomainMap(Fusion* fusion); - virtual ~DomainMap() = default; - - const ComputeAtMap& getComputeAtMap() const { - return ca_map_; - } - - // Determine if a TensorView is a valid reference tensor for this fusion. - // The reference tensor must map to all the iterDomains in each input. 
- bool isValidReference(TensorView* tv) const; - - protected: - // Determine if all IterDomains are mapped between input and the given tvs - bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) - const; - - virtual IterDomain* getMappedInputConcreteID( - const std::unordered_set& in_concrete_ids, - IterDomain* out_id) const; - - // Erase input concrete ID if it is mapped to output ID - bool eraseIfMapped( - std::unordered_set& in_concrete_ids, - IterDomain* out_id) const; - - // Check if in_ids are mapped to ids through any root domain as - // well as indirectly accessed domains with ops like torchGather - void eraseifInputMappedThroughRootDomainAndIndexing( - std::unordered_set& in_ids, - const std::vector& ids) const; - - // Find any id in domain that maps with target id - IterDomain* anyMapped( - const std::vector& domain, - IterDomain* target) const; - - Fusion* fusion_ = nullptr; - ComputeAtMap ca_map_; - std::vector tvs_with_rfactor_; -}; - // Returns number of non-reduction/non-broadcas/non-device dims in logical // domain inline int64_t nRootDims(const TensorView* tv) { @@ -74,5 +30,23 @@ inline int64_t nRootDims(const TensorView* tv) { return tv_n_dims; } +class PointwiseDomainMap : public scheduler_tools::DomainMap { + public: + using scheduler_tools::DomainMap::DomainMap; + + // The pointwise scheduler heuristics requires a minimum number of axes. + // The output reference tensor should respect this requirement. + TensorView* findReferenceTensor(int64_t minimum_num_axes = 0) const; + + private: + bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { + NVF_ERROR(tv != nullptr); + return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); + } +}; + +// Return reference tensor view. +TensorView* getReferenceTensor(Fusion* fusion); + } // namespace pointwise_utils } // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index fd9573d6b32..039d94b38af 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -90,6 +91,8 @@ std::unique_ptr SchedulerEntry::makeSchedulerInstance( return std::make_unique(); case SchedulerType::ExprEval: return std::make_unique(); + case SchedulerType::Resize: + return std::make_unique(); default: NVF_THROW("unreachable"); } diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp new file mode 100644 index 00000000000..fc96bd3db67 --- /dev/null +++ b/csrc/scheduler/resize.cpp @@ -0,0 +1,219 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +namespace { + +// Just use the pointwise version for now +TensorView* getReferenceTensor(Fusion* fusion) { + return pointwise_utils::getReferenceTensor(fusion); +} + +} // namespace + +bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { + if (!isOptionEnabled(EnableOption::ResizeScheduler)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Not enabled"); + return false; + } + + if (!ir_utils::hasOpsOfType(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "No resize op to schedule"); + return false; + } + + if (scheduler_utils::isResharding(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Fusion is resharding."); + return false; + } + + if (ir_utils::hasAnyReductionOps(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "No support for reduction ops"); + return false; + } + + if (registry_utils::hasNonUniqueBcast(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Broadcasting dimension might be broadcasting to multiple sizes."); + return false; + } + + // For now, the resize scheduler is only allowed for a limited set + // of fusion patterns. The restrictions are planned to be + // incrementally relaxed. + + IdModel id_model(fusion, /*build_graphs=*/false); + const auto& broadcast_graph = id_model.buildBroadcastGraph(); + + // For now, only a single resize op is allowed to exist. + auto resize_based_tensor_ops = ir_utils::getOpsOfType(fusion); + if (resize_based_tensor_ops.size() != 1) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Only a single resize op is allowed."); + return false; + } + + auto resize_out_tv = + resize_based_tensor_ops.at(0)->output(0)->as(); + + auto all_dep_vals = DependencyCheck::getAllValsBetween( + {fusion->inputs().begin(), fusion->inputs().end()}, {resize_out_tv}); + for (auto tv : ir_utils::filterByType(all_dep_vals)) { + if (tv == resize_out_tv) { + continue; + } + if (tv->isFusionOutput()) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Dependency to fusion output not allowed: ", + tv->toString()); + return false; + } + for (auto consumer_of_tv : ir_utils::consumerTvsOf(tv)) { + if (std::find(all_dep_vals.begin(), all_dep_vals.end(), consumer_of_tv) == + all_dep_vals.end()) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Resize inputs must be exclusively consumed by resize: ", + consumer_of_tv->toString()); + return false; + } + } + } + + // Slicing of or to a broadcast ID is not allowed yet. + for (auto tensor_op : resize_based_tensor_ops) { + TensorView* out_tv = tensor_op->output(0)->as(); + for (auto logical_id : out_tv->getLogicalDomain()) { + Resize* resize = dynamic_cast(logical_id->definition()); + if (resize == nullptr) { + continue; + } + + if (resize->out()->isBroadcast()) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Resize to a broadcast ID is not allowed."); + return false; + } + + // Need to check the broadcast group rather than just the input + // ID only. 
For example, + // + // t0: [i0] + // t1: [b1] + // t2 = t0 + t1 + // t3 = slice(t2) + // + // Then, propagating the slice to its inputs would try to + // propagate the resize op to b1 as well, which would fail due + // to issue #3571 + const auto& input_group = broadcast_graph.toGroup(resize->in()); + if (std::any_of( + input_group->begin(), input_group->end(), [](Val* inp_val) { + return inp_val->as()->isBroadcast(); + })) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), "Resize of a broadcast ID is not allowed."); + return false; + } + } + } + + // This doesn't work yet due to issue #3571 + auto ref_tv = getReferenceTensor(fusion); + if (std::any_of( + ref_tv->getLogicalDomain().begin(), + ref_tv->getLogicalDomain().end(), + [](IterDomain* logical_id) { return logical_id->isBroadcast(); })) { + return false; + } + + // Disable the scheduler if there's a squeeze op. The loop option + // may also need to be enabled in that case, but that option is not + // turned on automatically yet. + if (ir_utils::hasOpsOfType(fusion)) { + return false; + } + + return true; +} + +std::unique_ptr ResizeScheduler::computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache) { + FUSER_PERF_SCOPE("ResizeScheduler::computeHeuristics"); + auto params = std::make_unique(SchedulerType::Resize); + params->cparams.index_type = runtime_info.getIndexType(); + return params; +} + +void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { + FUSER_PERF_SCOPE("ResizeScheduler::schedule"); + + FusionGuard fg(fusion); + + scheduler_utils::clearMemorySpace(fusion); + + scheduler_utils::cacheInputs(fusion, true); + scheduler_utils::cacheAndForkOutputs(fusion, true); + + for (auto expr : fusion->exprs()) { + if (!expr->isOneOf()) { + continue; + } + + scheduler_tools::propagateResizeToInputs(expr); + } + + auto ref_tv = getReferenceTensor(fusion); + + // Just simple scheduling for now. + // TODO: Do something smarter. Can just use the pointwise scheduler? + + // Make sure the DID ID located at the outermost position + const auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv); + + // Schedule only the remaining IDs + ref_tv->flatten(outermost_pos); + ref_tv->split(outermost_pos, 128); + ref_tv->split(outermost_pos, 1 << 14); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + // Propagate the reference to the other tensors + scheduler_tools::scheduleLoopDomainsLike( + fusion->allTvs(), ref_tv->getLoopDomain()); + + inlineMost(); + + markAliases(fusion); +} + +} // namespace nvfuser diff --git a/csrc/scheduler/resize.h b/csrc/scheduler/resize.h new file mode 100644 index 00000000000..b51ecf1e6dd --- /dev/null +++ b/csrc/scheduler/resize.h @@ -0,0 +1,41 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +class Fusion; +class SchedulerRuntimeInfo; +class HeuristicDataCache; + +class ResizeScheduler : public SchedulerEntry { + public: + bool canScheduleCompileTime(Fusion* fusion) override; + bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache = nullptr) override { + return true; + } + + std::unique_ptr computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache) override; + + void schedule(Fusion* fusion, const HeuristicParams* params) override; + + constexpr static SchedulerType schedulerType() { + return SchedulerType::Resize; + } +}; + +} // namespace nvfuser diff --git a/csrc/scheduler/scheduler_types.cpp b/csrc/scheduler/scheduler_types.cpp index 623d5a22697..cf9b974acf5 100644 --- a/csrc/scheduler/scheduler_types.cpp +++ b/csrc/scheduler/scheduler_types.cpp @@ -31,6 +31,8 @@ std::string toString(SchedulerType scheduler_type) { return "matmul"; case SchedulerType::ExprEval: return "expr_eval"; + case SchedulerType::Resize: + return "resize"; case SchedulerType::None: return "none"; default: diff --git a/csrc/scheduler/scheduler_types.h b/csrc/scheduler/scheduler_types.h index 275a1f372e7..caa389abb9a 100644 --- a/csrc/scheduler/scheduler_types.h +++ b/csrc/scheduler/scheduler_types.h @@ -56,15 +56,17 @@ enum class SchedulerType { InnerOuterPersistent, OuterPersistent, Transpose, - ExprEval + ExprEval, + Resize }; //! Define a schedule table to loop over all the heuristics in priority order. -constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { SchedulerType::ExprEval, SchedulerType::NoOp, SchedulerType::Matmul, SchedulerType::Reduction, + SchedulerType::Resize, SchedulerType::Transpose, SchedulerType::PointWise, SchedulerType::InnerPersistent, diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp new file mode 100644 index 00000000000..0a713d346a0 --- /dev/null +++ b/csrc/scheduler/tools/domain_map.cpp @@ -0,0 +1,247 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +namespace nvfuser { +namespace scheduler_tools { + +namespace { +// Check if a root ID of a fusion input tensor that is indirectly +// accessed by ops such as torchGather needs to be mapped with +// a reference tensor. Select has a similar effect as squeeze as the +// indexed domain is removed, so the domain does not need to be mapped +// as long as the tensor is a fusion input. Similarly, in index_select +// and torchGather, if the output domain is a broadcast, it does not +// need to be mapped if not resolved. +bool canIgnoreIndexedInputDomainID( + TensorView* input_tv, + IterDomain* root_id, + const ComputeAtMap& ca_map) { + NVF_ERROR(input_tv->isFusionInput()); + for (auto use : input_tv->uses()) { + if (auto select = dynamic_cast(use)) { + if (root_id != select->getIndexedID()) { + return false; + } + } else if (auto index_select = dynamic_cast(use)) { + // If the root_id is an indexed ID, and the consumer ID may be a + // broadcast. 
In that case, nothing needs to be mapped if the + // consumer broadcast is not resolved + if (root_id != index_select->getIndexedID() || + !ca_map + .getConcreteMappedID( + index_select->getConsumerOfIndexedID(), + IdMappingMode::PERMISSIVE) + ->isBroadcast()) { + return false; + } + } else if (auto gather = dynamic_cast(use)) { + // TODO: Remove this. Once slice is used for torchGather, this + // should not be necessary. For now, it is necessary to not + // break the existing torchGather tests + if (!gather->exactSizes()) { + continue; + } + // If the root_id is an indexed ID, and the consumer ID may be a + // broadcast. In that case, nothing needs to be mapped if the + // consumer broadcast is not resolved + if (root_id != gather->getIndexedID() || + !ca_map + .getConcreteMappedID( + gather->getConsumerOfIndexedID(), IdMappingMode::PERMISSIVE) + ->isBroadcast()) { + return false; + } + } else { + // If the input TV is used by any other ops + return false; + } + } + + return true; +} + +// Grab all exact set mappings from consumer to producer domains of +// indexed accesses, e.g., index_select +std::unordered_multimap< + std::shared_ptr>, + std::shared_ptr>> +getIndexedConsumerToProducerMap(Fusion* fusion, const ComputeAtMap& ca_map) { + std::unordered_multimap< + std::shared_ptr>, + std::shared_ptr>> + indexed_id_map; + + for (auto expr : fusion->exprs()) { + if (auto gather = dynamic_cast(expr)) { + auto p_id = gather->getIndexedID(); + auto c_id = gather->getConsumerOfIndexedID(); + indexed_id_map.emplace( + ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), + ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); + } else if (auto index_select = dynamic_cast(expr)) { + auto p_id = index_select->getIndexedID(); + auto c_id = index_select->getConsumerOfIndexedID(); + indexed_id_map.emplace( + ca_map.disjointSetOf(c_id, IdMappingMode::EXACT), + ca_map.disjointSetOf(p_id, IdMappingMode::EXACT)); + } else { + // Note there's no consumer ID for select. This means we can't + // just propagate from consumers to indexed producers. It seems + // it's necessary to schedule producers and consumers separately + // in those cases. + continue; + } + } + + return indexed_id_map; +} + +} // namespace + +DomainMap::DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { + tvs_with_rfactor_ = scheduler_utils::getTVsWithNonReductionRFactor(fusion); +} + +// Determine if all IterDomains in input are mapped to the given tensor +bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) + const { + // Get concrete IDs for input root or logical domain + std::unordered_set in_concrete_ids; + for (auto in_id : input_tv->getLogicalDomain()) { + if (canIgnoreIndexedInputDomainID(input_tv, in_id, ca_map_)) { + continue; + } + + // Permissive map is required for the transpose scheduler to support cases + // like T0[I0, b] + T1[b, I1] + auto concrete = + ca_map_.getConcreteMappedID(in_id, IdMappingMode::PERMISSIVE); + + if (!concrete->isBroadcast() && !in_id->isReduction()) { + in_concrete_ids.insert(concrete); + } + } + + // Erase all input concrete IDs mapped to the output domain + // Ignore unresolved broadcast dimensions + eraseifInputMappedThroughRootDomainAndIndexing( + in_concrete_ids, tv->getLogicalDomain()); + + return in_concrete_ids.empty(); +} + +// Reference domains must exactly match with the input domains. 
See +// also PR #661 +IterDomain* DomainMap::getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const { + auto in_concrete_id_iter = std::find_if( + in_concrete_ids.begin(), + in_concrete_ids.end(), + [&](IterDomain* in_concrete_id) { + return ca_map_.areMapped(in_concrete_id, out_id, IdMappingMode::EXACT); + }); + if (in_concrete_id_iter != in_concrete_ids.end()) { + return *in_concrete_id_iter; + } else { + return nullptr; + } +} + +// Erase input concrete ID if it is mapped to output ID +bool DomainMap::eraseIfMapped( + std::unordered_set& in_concrete_ids, + IterDomain* out_id) const { + auto mapped_input_conrete_id = + getMappedInputConcreteID(in_concrete_ids, out_id); + if (mapped_input_conrete_id != nullptr) { + in_concrete_ids.erase(mapped_input_conrete_id); + return true; + } else { + return false; + } +} + +void DomainMap::eraseifInputMappedThroughRootDomainAndIndexing( + std::unordered_set& in_ids, + const std::vector& ids) const { + // Use ComputeAtMap::getAllDisjointSetProducers to grab all producer + // IDs through rfactor exprs + VectorOfUniqueEntries>> + exact_sets; + std::for_each(ids.begin(), ids.end(), [&](IterDomain* id) { + exact_sets.pushBack(ca_map_.disjointSetOf(id, IdMappingMode::EXACT)); + }); + + // Traverse through indexed domains. + const auto indexed_id_multimap = + getIndexedConsumerToProducerMap(fusion_, ca_map_); + + VectorOfUniqueEntries>> + all_exact_sets_covered; + + // Back traverses through the exact map and indexed + // producer-consumer pairs + for (auto current_sets = exact_sets; !current_sets.empty();) { + auto producer_sets = ca_map_.getAllDisjointSetProducers(current_sets); + all_exact_sets_covered.pushBack(producer_sets); + + current_sets.clear(); + + // Further traversal if any of the new producer sets is a producer + // of indexed domains + for (const auto& producer_set : producer_sets) { + auto indexed_id_multimap_range = + indexed_id_multimap.equal_range(producer_set); + for (auto producer_of_producer_it = indexed_id_multimap_range.first; + producer_of_producer_it != indexed_id_multimap_range.second; + ++producer_of_producer_it) { + current_sets.pushBack(producer_of_producer_it->second); + } + } + } + + for (const auto& exact_set_ptr : all_exact_sets_covered) { + auto exact_concrete_id = ca_map_.getConcreteMappedID( + exact_set_ptr->front(), IdMappingMode::EXACT); + eraseIfMapped(in_ids, exact_concrete_id); + } +} + +// Find any id in domain that maps with target id +IterDomain* DomainMap::anyMapped( + const std::vector& domain, + IterDomain* target) const { + for (auto id : domain) { + if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) { + return id; + } + } + return nullptr; +} + +// Determine if output TensorView is a valid reference tensor for this fusion. +// The reference tensor must map to all the iterDomains in each input. +bool DomainMap::isValidReference(TensorView* tv) const { + for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { + if (input_tv->uses().empty()) { + continue; + } + // TODO: Same backward traversal from tv is done for all input + // tvs. 
Consider doing the analysis one for all inputs + if (!areAllInputIdsMappedTo(input_tv, tv)) { + return false; + } + } + return true; +} + +} // namespace scheduler_tools +} // namespace nvfuser diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h new file mode 100644 index 00000000000..88dadcba721 --- /dev/null +++ b/csrc/scheduler/tools/domain_map.h @@ -0,0 +1,69 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include +#include + +namespace nvfuser { + +class Fusion; +class TensorView; +class IterDomain; + +namespace scheduler_tools { + +// DomainMap uses the ComputeAtMap to find a reference TensorView +// that maps to all IterDomains in the fusion. +class DomainMap { + public: + DomainMap(Fusion* fusion); + virtual ~DomainMap() = default; + + const ComputeAtMap& getComputeAtMap() const { + return ca_map_; + } + + // Determine if a TensorView is a valid reference tensor for this fusion. + // The reference tensor must map to all the iterDomains in each input. + bool isValidReference(TensorView* tv) const; + + protected: + // Determine if all IterDomains are mapped between input and the given tvs + bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) + const; + + virtual IterDomain* getMappedInputConcreteID( + const std::unordered_set& in_concrete_ids, + IterDomain* out_id) const; + + // Erase input concrete ID if it is mapped to output ID + bool eraseIfMapped( + std::unordered_set& in_concrete_ids, + IterDomain* out_id) const; + + // Check if in_ids are mapped to ids through any root domain as + // well as indirectly accessed domains with ops like torchGather + void eraseifInputMappedThroughRootDomainAndIndexing( + std::unordered_set& in_ids, + const std::vector& ids) const; + + // Find any id in domain that maps with target id + IterDomain* anyMapped( + const std::vector& domain, + IterDomain* target) const; + + Fusion* fusion_ = nullptr; + ComputeAtMap ca_map_; + std::vector tvs_with_rfactor_; +}; + +} // namespace scheduler_tools +} // namespace nvfuser diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 7e320f99a91..95100c7fc5f 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -155,9 +155,9 @@ bool hasSmallTransposeDimensions( // DomainMap uses the ComputeAtMap to find a reference TensorView // that maps to all iterDomains in the fusion. 
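To make that requirement concrete, a small hypothetical sketch against the now-shared `scheduler_tools::DomainMap` (it reuses the `makeConcreteTensor` helper that appears in the tests below and assumes extent-1 dimensions are created as broadcast IterDomains):

```cpp
// Sketch: neither input alone covers every IterDomain, but the output of the
// broadcast-resolving add maps to all of them, so it is a valid reference.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({16, 1});  // [I0, b1]
auto tv1 = makeConcreteTensor({1, 32});  // [b0, I1]
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = add(tv0, tv1);  // [I0, I1], resolves both broadcasts
fusion.addOutput(tv2);

scheduler_tools::DomainMap domain_map(&fusion);
// Broadcast and reduction IDs are ignored when collecting input IDs, so tv2's
// logical domain accounts for I0 of tv0 and I1 of tv1.
NVF_ERROR(domain_map.isValidReference(tv2));
```

A tensor that covered only one of the two concrete extents (e.g., tv0 itself) would fail this check, which is why the pointwise, transpose, and resize schedulers all search the fusion outputs for a covering reference.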
-class DomainMap : public pointwise_utils::DomainMap { +class TransposeDomainMap : public scheduler_tools::DomainMap { public: - using pointwise_utils::DomainMap::DomainMap; + using scheduler_tools::DomainMap::DomainMap; // Note that this may not be able to find any reference if any // tensor in the group is only connected with an input through @@ -196,7 +196,7 @@ class DomainMap : public pointwise_utils::DomainMap { static bool hasAtLeastTwoValidGroups(Fusion* fusion) { FusionGuard fg(fusion); - DomainMap domain_map(fusion); + TransposeDomainMap domain_map(fusion); auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); if (grouped_inputs_outputs.size() < 2) { return false; @@ -560,12 +560,14 @@ HeuristicDataCacheEntry getDomainMap( auto domain_map_entry = HeuristicDataCacheEntry( data_cache, - [fusion]() { return std::make_unique(fusion); }); + [fusion]() { return std::make_unique(fusion); }); return domain_map_entry; } HeuristicDataCacheEntry -getInputsOutputsGroups(HeuristicDataCache* data_cache, DomainMap& domain_map) { +getInputsOutputsGroups( + HeuristicDataCache* data_cache, + TransposeDomainMap& domain_map) { auto grouped_inputs_outputs_entry = HeuristicDataCacheEntry< HeuristicCompileTime::InputsOutputsInnerDimGroups>( data_cache, [&domain_map]() { @@ -584,7 +586,7 @@ getInputsOutputsGroups(HeuristicDataCache* data_cache, DomainMap& domain_map) { HeuristicDataCacheEntry getReferenceTensors( HeuristicDataCache* data_cache, - DomainMap& domain_map, + TransposeDomainMap& domain_map, std::vector>& grouped_inputs_outputs) { auto reference_tensors_entry = HeuristicDataCacheEntry( @@ -609,7 +611,7 @@ std::pair, int64_t> getShapeInReference( HeuristicDataCache* data_cache, SchedulerRuntimeInfo& runtime_info, TensorView* reference, - DomainMap& domain_map) { + TransposeDomainMap& domain_map) { auto ref_logical = reference->getLogicalDomain(); std::vector shape_in_ref; shape_in_ref.reserve(reference->nDims()); @@ -635,7 +637,7 @@ getInnerMostDimInfoInReference( HeuristicDataCache* data_cache, const std::vector& group_references, TensorView* global_reference, - DomainMap& domain_map) { + TransposeDomainMap& domain_map) { auto innermost_info_entry = HeuristicDataCacheEntry( data_cache, [&]() { @@ -659,7 +661,7 @@ std::string getTransposeRuntimeRejectReason( HeuristicDataCache* data_cache, SchedulerRuntimeInfo& runtime_info) { auto domain_map_entry = getDomainMap(data_cache, fusion); - auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto& domain_map = dynamic_cast(domain_map_entry.get()); auto grouped_inputs_outputs_entry = getInputsOutputsGroups(data_cache, domain_map); auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); @@ -789,7 +791,7 @@ std::string getTransposeRuntimeRejectReason( } // namespace bool hasAtLeastTwoValidGroups(Fusion* fusion) { - return DomainMap::hasAtLeastTwoValidGroups(fusion); + return TransposeDomainMap::hasAtLeastTwoValidGroups(fusion); } std::unique_ptr getTransposeHeuristics( @@ -802,7 +804,7 @@ std::unique_ptr getTransposeHeuristics( const auto index_type = runtime_info.getIndexType(); auto domain_map_entry = getDomainMap(data_cache, fusion); - auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto& domain_map = dynamic_cast(domain_map_entry.get()); auto grouped_inputs_outputs_entry = getInputsOutputsGroups(data_cache, domain_map); auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); @@ -1057,7 +1059,7 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { return; } - DomainMap 
domain_map(fusion); + TransposeDomainMap domain_map(fusion); auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); NVF_ERROR(grouped_inputs_outputs.size() >= 2); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 667ea7f40f2..79920ec96c7 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2661,6 +2661,19 @@ void moveNonConcretizedBroadcastInnermost( } } +int64_t reorderDevicesToOuter(TensorView* tv) { + int64_t reorder_pos = 0; + std::unordered_map old2new; + for (const auto i : c10::irange(tv->getLoopDomain().size())) { + if (tv->axis((int64_t)i)->isDeviceDim()) { + old2new.emplace((int64_t)i, reorder_pos); + ++reorder_pos; + } + } + tv->reorder(old2new); + return (int64_t)old2new.size(); +} + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 77317cde31b..29f7f12efc6 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -729,5 +729,9 @@ void moveNonConcretizedBroadcastInnermost( Fusion* fusion, const std::unordered_set& ignored_tvs = {}); +// Reorder DID parallelized axes to outermost positions. Returns +// the position of the outermost non-DID axis. +int64_t reorderDevicesToOuter(TensorView* tv); + } // namespace scheduler_utils } // namespace nvfuser diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index e9fb2dcf503..630c77ac458 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -959,34 +959,6 @@ TEST_F(AliasTest, SourceIsBothInputAndOutput) { EXPECT_EQ(in_tensor.data_ptr(), out_tensors[1].data_ptr()); } -TEST_F(AliasTest, SegmentBoundary) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* in = makeContigConcreteTensor({2, 3}); - TensorView* out = permute(in, {1, 0}); - // With the current segmentation algorithm, `slice` has to be the start of a - // fusion. So we expect `permute` to form a meta-op-only segment and the rest - // a pointwise segment. 
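Returning to the `scheduler_utils::reorderDevicesToOuter` helper added in csrc/scheduler/utils.cpp above, a short worked example of what it does to a loop domain (the domain contents and the `DIDx` notation are illustrative):

```cpp
// Illustrative only: suppose ref_tv's loop domain is [i0, DIDx{d}, i1], i.e.
// the device-parallel axis sits in the middle. The helper collects
// old2new = {1 -> 0}, calls ref_tv->reorder(old2new) to obtain
// [DIDx{d}, i0, i1], and returns old2new.size() == 1, the position of the
// outermost non-DID axis. ResizeScheduler::schedule then flattens, splits,
// and parallelizes only the axes from that position onward, leaving the DID
// axis outermost and untouched.
const int64_t outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);
```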
- out = slice(out, {0, 0}, {2, 2}); - out = add(out, out); - fusion->addInput(in); - fusion->addOutput(out); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor in_tensor = at::randn({2, 3}).cuda(); - at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; - testValidate( - executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); - - FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre( - HeuristicIs(SchedulerType::NoOp), - HeuristicIs(SchedulerType::PointWise))); -} - TEST_F(AliasTest, ReuseBuffer) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 8fa235357b1..66087eab2f5 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8051,23 +8051,27 @@ TEST_F(NVFuserTest, AvoidCachingSliceInput) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(inputs); - // check segment and sliced tvs are not cached + // check segmentation and sliced tvs are not cached if not scheduled by + // the resize scheduler auto kernel_runtime = executor_cache.getMostRecentKernelRuntime(); - NVF_CHECK(kernel_runtime->isSegmented(), "segmentation didn't happen"); const auto num_segments = kernel_runtime->fusionSegments()->groups().size(); - NVF_CHECK(num_segments == 3, "Expect 3 segments, got: ", num_segments); - for (const auto& exec : kernel_runtime->executors()) { + EXPECT_EQ(num_segments, 3) << "Expect 3 segments, got: " << num_segments; + for (const auto i : c10::irange(kernel_runtime->executors().size())) { + const auto& exec = kernel_runtime->executors().at(i); if (!exec->isA()) { continue; } + if (kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(i) + ->scheduler_type == SchedulerType::Resize) { + continue; + } const auto* ke = exec->as(); for (auto expr : ke->fusion()->exprs()) { if (expr->isA()) { auto slice = expr->as(); - NVF_CHECK( - slice->in()->getMemoryType() == MemoryType::Global, - "slice input must be in global memory, get: ", - slice->in()->getMemoryType()); + EXPECT_EQ(slice->in()->getMemoryType(), MemoryType::Global); } } } diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 0b7e816cc46..fb6c74512f2 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -55,7 +55,27 @@ void checkLoopDomainEquivalence( } // namespace -using ResizeTest = NVFuserTest; +class ResizeTest : public NVFuserTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + NVFuserTest::SetUp(); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; + +class ResizeSchedulerTest : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + NVFuserFixtureParamTest::SetUp(); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; using testing::Each; using testing::HasSubstr; @@ -64,6 +84,14 @@ using testing::Property; using testing::ThrowsMessage; using testing::UnorderedElementsAre; +INSTANTIATE_TEST_SUITE_P( + , + ResizeSchedulerTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? 
"Scheduler" : "Manual"; + }); + // Simple pad test TEST_F(ResizeTest, Pad1) { Fusion fusion; @@ -2055,7 +2083,10 @@ TEST_F(ResizeTest, ResizeReshapeAndSlice) { } // Make sure resize works with the transpose scheduler -TEST_F(ResizeTest, ResizePermuteAndSlice) { +// This is consumed by the resize scheduler. We should extend the +// transpose scheduler to support resize without the segment-input +// requirement. +TEST_F(ResizeTest, DISABLED_ResizePermuteAndSlice) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -3046,11 +3077,9 @@ TEST_F(ResizeTest, ReshapeToPad) { auto outputs = executor_cache.runFusionWithInputs(aten_inputs); - // Assert that we segmented into two segments auto seg_fusion = executor_cache.getMostRecentKernelRuntime()->fusionSegments(); - EXPECT_TRUE(seg_fusion->isSegmented()); - EXPECT_EQ(seg_fusion->groups().size(), 2); + EXPECT_EQ(seg_fusion->groups().size(), 1); testValidate( executor_cache.fusion(), @@ -3990,9 +4019,10 @@ TEST_F(ResizeTest, SliceSliceConcatConcat) { } // Consumer-based scheduling of slice -TEST_F(ResizeTest, PropagateSliceToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4002,64 +4032,85 @@ TEST_F(ResizeTest, PropagateSliceToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + // Dont't use set here as it gets taken by the no-op scheduler + auto tv1 = sin(tv0); auto tv2 = slice( tv1, {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, {IrBuilder::create(1L), IrBuilder::create(99)}}); - auto tv3 = set(tv2); + auto tv3 = cos(tv2); fusion.addOutput(tv3); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - auto ref_tv = tv3; + const bool use_scheduler = GetParam(); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv3; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if 
(tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Propagating slice to inputs with reshape before slice -TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputsWithReshape1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({16, 100}); @@ -4069,7 +4120,7 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = reshape(tv1, shape, {16, 5, 20}); @@ -4079,57 +4130,77 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { {fusion.zeroVal(), tv2->getLogicalDomain().at(1)->extent()}, {IrBuilder::create(1L), IrBuilder::create(10)}}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - scheduler_tools::propagateResizeToInputs(tv3->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv3->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv4; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Propagating slice to inputs with reshape after slice -TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateSliceToInputsWithReshape2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({16, 100}); @@ -4139,7 +4210,7 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, @@ -4148,53 +4219,73 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape2) { auto tv3 = reshape(tv2, {shape[0], 49}, {shape[0] * 49}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + auto ref_tv = tv4; - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } -TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateMultipleSlicesToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4204,7 +4295,7 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, @@ -4216,69 +4307,91 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { {{fusion.zeroVal(), tv2->getLogicalDomain().at(0)->extent()}, {IrBuilder::create(1L), tv2->getLogicalDomain().at(1)->extent()}}); - auto tv4 = set(tv3); + auto tv4 = cos(tv3); fusion.addOutput(tv4); - // Propagate the first slice to tv1 - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - // Propagate the second slice to tv1 and tv2 - scheduler_tools::propagateResizeToInputs(tv3->definition()); + const bool use_scheduler = GetParam(); - // Each of tv1 and tv2 has two resize ops. - for (auto tv : {tv1, tv2}) { - auto resize1 = dynamic_cast(tv->axis(-1)->definition()); - EXPECT_NE(resize1, nullptr); - auto resize2 = dynamic_cast(resize1->in()->definition()); - EXPECT_NE(resize2, nullptr) << tv->toString(); - } + if (!use_scheduler) { + // Propagate the first slice to tv1 + scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto ref_tv = tv4; + // Propagate the second slice to tv1 and tv2 + scheduler_tools::propagateResizeToInputs(tv3->definition()); - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Each of tv1 and tv2 has two resize ops. 
+    for (auto tv : {tv1, tv2}) {
+      auto resize1 = dynamic_cast<Resize*>(tv->axis(-1)->definition());
+      EXPECT_NE(resize1, nullptr);
+      auto resize2 = dynamic_cast<Resize*>(resize1->in()->definition());
+      EXPECT_NE(resize2, nullptr) << tv->toString();
+    }
 
-  // Schedule the reference
-  ref_tv->flatten();
-  // For TIDx
-  ref_tv->split(0, 128);
-  // For BIDx
-  ref_tv->split(0, 4);
+    auto ref_tv = tv4;
 
-  scheduler_tools::scheduleLoopDomainsLike(
-      fusion.allTvs(), ref_tv->getLoopDomain());
+    // Fusion should have a uniform loop domain
+    checkLoopDomainEquivalence(ref_tv);
 
-  // Fusion should still have a uniform loop domain
-  checkLoopDomainEquivalence(ref_tv);
+    // Schedule the reference
+    ref_tv->flatten();
+    // For TIDx
+    ref_tv->split(0, 128);
+    // For BIDx
+    ref_tv->split(0, 4);
 
-  inlineMost();
+    scheduler_tools::scheduleLoopDomainsLike(
+        fusion.allTvs(), ref_tv->getLoopDomain());
 
-  // All tensors, except for fusion inputs, should be fully inlined
-  for (auto tv : fusion.allTvs()) {
-    if (tv->isFusionInput()) {
-      continue;
+    // Fusion should still have a uniform loop domain
+    checkLoopDomainEquivalence(ref_tv);
+
+    inlineMost();
+
+    // All tensors, except for fusion inputs, should be fully inlined
+    for (auto tv : fusion.allTvs()) {
+      if (tv->isFusionInput()) {
+        continue;
+      }
+      EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims());
     }
-    EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims());
-  }
-  ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
-  ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
+    ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
+    ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto t0 = at::randn({16, 100}, options);
-  std::vector<c10::IValue> inputs({t0});
+    KernelExecutor ke;
+    ke.compile(&fusion, inputs);
+    auto outputs = ke.run(inputs);
+    testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+  } else {
+    GTEST_SKIP() << "Scheduling not yet supported";
 
-  KernelExecutor ke;
-  ke.compile(&fusion, inputs);
-  auto outputs = ke.run(inputs);
-  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+    FusionExecutorCache executor_cache(std::move(fusion_ptr));
+    auto out_tensors = executor_cache.runFusionWithInputs(inputs);
+    testValidate(
+        executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__);
+    FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+    EXPECT_FALSE(runtime->isSegmented());
+    const auto& heuristic_param =
+        runtime->schedulerHeuristics()->heuristicsList().front();
+    EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize);
+    Fusion* scheduled_fusion =
+        dynamic_cast<KernelExecutor*>(runtime->executors().at(0).get())
+            ->fusion();
+    checkLoopDomainEquivalence(
+        scheduled_fusion->outputs().at(0)->as<TensorView>());
+  }
 }
 
 // RoPE-like rotation patten
-TEST_F(ResizeTest, SliceRotateCat) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+TEST_P(ResizeSchedulerTest, SliceRotateCat) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
 
   std::vector<int64_t> shape({-1, 100});
 
@@ -4288,14 +4401,14 @@ TEST_F(ResizeTest, SliceRotateCat) {
   auto tv0 = makeConcreteTensor(shape);
   fusion.addInput(tv0);
 
-  auto tv1 = set(tv0);
+  auto tv1 = sin(tv0);
 
   auto tv2 = slice(
       tv1,
       {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()},
        {fusion.zeroVal(), IrBuilder::create<Val>(shape[1] / 2)}});
 
-  auto tv3 = set(tv0);
+  auto tv3 = sin(tv0);
 
   auto tv4 = slice(
       tv3,
@@ -4305,75 +4418,97 @@ TEST_F(ResizeTest, SliceRotateCat) {
 
   auto tv5 = cat({tv4, tv2}, 1);
 
-
fusion.addOutput(tv5); - - // Propagate the left half of slice and pad - scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto pad_left = - dynamic_cast(tv5->definition()->input(0)->definition()); - scheduler_tools::propagateResizeToInputs(pad_left); - - // Propagate the right half of slice and pad - scheduler_tools::propagateResizeToInputs(tv4->definition()); - auto pad_right = - dynamic_cast(tv5->definition()->input(1)->definition()); - scheduler_tools::propagateResizeToInputs(pad_right); - - auto ref_tv = tv5; - - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + fusion.addOutput(tv5); - { - IdModel id_model(&fusion, false); - id_model.buildExactGraph(); - std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); - auto dot_string = - id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); - ofs << dot_string; - ofs.close(); - } + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv5; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - inlineMost(); + inlineMost(); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, 
__FILE__); + } else { + GTEST_SKIP() << "Scheduling not yet supported"; - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // RoPE-like rotation and residual patten -TEST_F(ResizeTest, SliceRotateCatResidual) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, SliceRotateCatResidual) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4383,14 +4518,14 @@ TEST_F(ResizeTest, SliceRotateCatResidual) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = slice( tv1, {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); - auto tv3 = set(tv0); + auto tv3 = sin(tv0); auto tv4 = slice( tv3, @@ -4404,74 +4539,96 @@ TEST_F(ResizeTest, SliceRotateCatResidual) { fusion.addOutput(tv6); - // Propagate the left half of slice and pad - scheduler_tools::propagateResizeToInputs(tv2->definition()); - auto pad_left = - dynamic_cast(tv5->definition()->input(1)->definition()); - scheduler_tools::propagateResizeToInputs(pad_left); - - // Propagate the right half of slice and pad - scheduler_tools::propagateResizeToInputs(tv4->definition()); - auto pad_right = - dynamic_cast(tv5->definition()->input(0)->definition()); - scheduler_tools::propagateResizeToInputs(pad_right); - - auto ref_tv = tv6; - - // Fusion should have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); - - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - { - IdModel id_model(&fusion, false); - id_model.buildExactGraph(); - std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); - auto dot_string = - id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); - ofs << dot_string; - ofs.close(); - } + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv6; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 
128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - inlineMost(); + inlineMost(); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()) + << "Invalid computeAt position of " << tv->toString(); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()) - << "Invalid computeAt position of " << tv->toString(); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + GTEST_SKIP() << "Scheduling not yet supported"; - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Consumer-based scheduling of pad -TEST_F(ResizeTest, PropagatePadToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagatePadToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4481,61 +4638,81 @@ TEST_F(ResizeTest, PropagatePadToInputs) { auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); - auto tv1 = set(tv0); + auto tv1 = sin(tv0); auto tv2 = pad(tv1, {fusion.oneVal(), IrBuilder::create(2L)}); - auto tv3 = set(tv2); + auto tv3 = cos(tv2); fusion.addOutput(tv3); - scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); - auto ref_tv = tv3; + const bool use_scheduler = GetParam(); - // Fusion should have a 
uniform loop domain - checkLoopDomainEquivalence(ref_tv); + if (!use_scheduler) { + scheduler_tools::propagateResizeToInputs(tv2->definition()); - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + auto ref_tv = tv3; - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - inlineMost(); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; - } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + inlineMost(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - std::vector inputs({t0}); + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // Consumer-based scheduling of cat -TEST_F(ResizeTest, PropagateCatToInputs) { - Fusion fusion; - FusionGuard fg(&fusion); +TEST_P(ResizeSchedulerTest, PropagateCatToInputs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); std::vector shape({-1, 100}); @@ -4547,65 +4724,86 @@ TEST_F(ResizeTest, PropagateCatToInputs) { auto tv1 = makeConcreteTensor(shape); fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = set(tv1); + auto tv2 = sin(tv0); + auto tv3 = sin(tv1); auto tv4 = cat({tv2, tv3}, -1); - auto tv5 = set(tv4); + auto tv5 = cos(tv4); fusion.addOutput(tv5); - // Propagate the pad op of each cat input - for (auto cat_inp : - ir_utils::filterByType(tv4->definition()->inputs())) { - auto pad_op = dynamic_cast(cat_inp->definition()); - ASSERT_NE(pad_op, nullptr); - scheduler_tools::propagateResizeToInputs(pad_op); - auto pad_inp = 
pad_op->input(0)->as(); - checkLoopDomainEquivalence(cat_inp, {pad_inp}); - } + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + auto t1 = at::randn({16, 100}, options); + std::vector inputs({t0, t1}); - auto ref_tv = tv4; + const bool use_scheduler = GetParam(); + + if (!use_scheduler) { + // Propagate the pad op of each cat input + for (auto cat_inp : + ir_utils::filterByType(tv4->definition()->inputs())) { + auto pad_op = dynamic_cast(cat_inp->definition()); + ASSERT_NE(pad_op, nullptr); + scheduler_tools::propagateResizeToInputs(pad_op); + auto pad_inp = pad_op->input(0)->as(); + checkLoopDomainEquivalence(cat_inp, {pad_inp}); + } - // At this point, all tensors should have the same loop domain - checkLoopDomainEquivalence(ref_tv); + auto ref_tv = tv4; - // Schedule the reference - ref_tv->flatten(); - // For TIDx - ref_tv->split(0, 128); - // For BIDx - ref_tv->split(0, 4); + // At this point, all tensors should have the same loop domain + checkLoopDomainEquivalence(ref_tv); - scheduler_tools::scheduleLoopDomainsLike( - fusion.allTvs(), ref_tv->getLoopDomain()); + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); - // Fusion should still have a uniform loop domain - checkLoopDomainEquivalence(ref_tv); + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain()); - inlineMost(); + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); - // All tensors, except for fusion inputs, should be fully inlined - for (auto tv : fusion.allTvs()) { - if (tv->isFusionInput()) { - continue; + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); } - EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); - } - ref_tv->axis(-1)->parallelize(ParallelType::TIDx); - ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); - auto t1 = at::randn({16, 100}, options); - std::vector inputs({t0, t1}); + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + } else { + GTEST_SKIP() << "Scheduling not yet supported"; - KernelExecutor ke; - ke.compile(&fusion, inputs); - auto outputs = ke.run(inputs); - testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto out_tensors = executor_cache.runFusionWithInputs(inputs); + testValidate( + executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = + dynamic_cast(runtime->executors().at(0).get()) + ->fusion(); + checkLoopDomainEquivalence( + scheduled_fusion->outputs().at(0)->as()); + } } // manual scheduling that should have vectorized load on padded inputs. 
diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp
index b893c5c29ad..8bcf3b13d60 100644
--- a/tests/cpp/test_segmentation.cpp
+++ b/tests/cpp/test_segmentation.cpp
@@ -552,51 +552,6 @@ TEST_F(SegmentationTest, ForceBf16NotAllCast) {
   }
 }
 
-// Test that a segment with a slice does not introduce a cast
-// See https://github.com/NVIDIA/Fuser/pull/1936
-TEST_F(SegmentationTest, SliceSegmentCasts) {
-  EnableOptionsGuard opt_guard;
-  EnableOptionsGuard::getCurOptions().set(EnableOption::IoToLowerPrecision);
-
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  auto tv0 = makeSymbolicTensor(1, DataType::Half);
-
-  fusion->addInput(tv0);
-
-  // Group 1
-  auto tv1 = mul(tv0, tv0);
-  // Group 2
-  auto tv2 = slice(tv1, {0}, {3}, {1});
-  auto tv3 = add(tv2, tv2);
-
-  fusion->addOutput(tv3);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  auto in0 = at::randn({5}, options);
-  auto outputs = executor_cache.runFusionWithInputs({in0});
-
-  SegmentedFusion* segmented_fusion =
-      executor_cache.getMostRecentKernelRuntime()->fusionSegments();
-
-  ASSERT_EQ(segmented_fusion->edges().size(), 1);
-
-  SegmentedEdge* slice_edge = segmented_fusion->edges().at(0);
-
-  // Expect edge to be half-precision
-  // TODO: Change this rhs to DataType::Half once we have addressed
-  // https://github.com/NVIDIA/Fuser/issues/1902
-  EXPECT_EQ(slice_edge->val->getDataType(), DataType::Float);
-
-  // There should be no cast before the slice
-  EXPECT_TRUE(slice_edge->val->uses().at(0)->isA<SliceOp>());
-
-  testValidate(executor_cache.fusion(), outputs, {in0}, __LINE__, __FILE__);
-}
-
 TEST_F(SegmentationTest, codeGenSupportedMergeIssue1970) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py
index 1e2baeffadf..0b74fddeae6 100644
--- a/tests/python/test_python_frontend.py
+++ b/tests/python/test_python_frontend.py
@@ -3331,7 +3331,10 @@ def fusion_func(fd: FusionDefinition) -> None:
             fd.add_output(T27)
             fd.add_output(T28)
 
-        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True)
+        # TODO: Support segmentation. See #3594.
+        nvf_out, _ = self.exec_nvfuser(
+            fusion_func, inputs, is_clonable=True, supports_segmentation=False
+        )
 
         t12 = inputs[1] * inputs[-2]
         t13 = torch.permute(t12, [0, 1, 3, 2])
@@ -4383,7 +4386,10 @@ def fusion_func(fd: FusionDefinition):
             fd.add_output(T18)
             fd.add_output(T16)
 
-        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=True)
+        # TODO: Support segmentation. See #3594.
+        nvf_out, _ = self.exec_nvfuser(
+            fusion_func, inputs, is_clonable=True, supports_segmentation=False
+        )
 
     def test_returning_aliased_outputs(self):
         inputs = [torch.randn((1, 2, 3, 4), dtype=torch.float32, device="cuda:0")]