Commit ea07803

Merge branch 'main' of https://github.com/nvidia/fuser into autotune_outer_reduction
rdspring1 committed Jan 7, 2025
2 parents 4b29b02 + 3ae5468
Showing 113 changed files with 8,540 additions and 3,121 deletions.
28 changes: 20 additions & 8 deletions CMakeLists.txt
@@ -136,6 +136,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/host_ir/container.cpp
${NVFUSER_SRCS_DIR}/host_ir/executor.cpp
${NVFUSER_SRCS_DIR}/host_ir/host_ir.cpp
${NVFUSER_SRCS_DIR}/host_ir/lower.cpp
${NVFUSER_SRCS_DIR}/id_model/circular_buffer_indexing.cpp
${NVFUSER_SRCS_DIR}/id_model/contiguity.cpp
${NVFUSER_SRCS_DIR}/id_model/id_model.cpp
@@ -170,7 +171,6 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
${NVFUSER_SRCS_DIR}/multidevice/executor.cpp
- ${NVFUSER_SRCS_DIR}/multidevice/lower_communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/utils.cpp
${NVFUSER_SRCS_DIR}/mutator.cpp
${NVFUSER_SRCS_DIR}/non_divisible_split.cpp
@@ -200,6 +200,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
${NVFUSER_SRCS_DIR}/rng.cpp
${NVFUSER_SRCS_DIR}/runtime/allocations.cpp
${NVFUSER_SRCS_DIR}/runtime/executor.cpp
@@ -231,8 +232,10 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
@@ -294,13 +297,18 @@ endif()
add_library(codegen_internal OBJECT ${NVFUSER_SRCS})

if(NOT MSVC)
- # -Werror is not enabled, because of gcc 12.2 used in manylinux image.
- # consider enable this when we upgrade. linking comment:
- # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
- target_compile_options(codegen_internal PRIVATE
-   -Wall -Wno-unused-function
-   # -Werror
- )
+ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
+   target_compile_options(codegen_internal PRIVATE
+     -Wall -Wno-unused-function -Werror
+     # These warnings are not treated as errors because of gcc 12.2 used in
+     # the manylinux image; consider enabling them when we upgrade. See the
+     # linked comment:
+     # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
+     -Wno-error=restrict -Wno-error=stringop-overflow)
+ else()
+   target_compile_options(codegen_internal PRIVATE
+     -Wall -Wno-unused-function -Werror)
+ endif()
endif()

target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
@@ -440,6 +448,7 @@ if(BUILD_PYTHON)
list(APPEND NVFUSER_PYTHON_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp
)

add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS})
@@ -573,6 +582,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_resharding.cpp
${NVFUSER_ROOT}/tests/cpp/test_resize.cpp
${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp
${NVFUSER_ROOT}/tests/cpp/test_rope.cpp
${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp
${NVFUSER_ROOT}/tests/cpp/test_scatter_gather.cpp
${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp
@@ -584,6 +594,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_tensor_factories.cpp
${NVFUSER_ROOT}/tests/cpp/test_unary.cpp
${NVFUSER_ROOT}/tests/cpp/test_utils.cpp
${NVFUSER_ROOT}/tests/cpp/test_vectorization_analysis.cpp
)

if(BUILD_TEST)
@@ -644,6 +655,7 @@ if(BUILD_TEST)
set(MULTIDEVICE_TEST_SRCS)
list(APPEND MULTIDEVICE_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_overlap.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
1 change: 0 additions & 1 deletion benchmarks/python/core.py
@@ -152,7 +152,6 @@ def torchprofile_timer(self) -> float:
# Clear the internal profiler object to avoid accumulating function events and then restart the profiler
# See PR: https://github.com/pytorch/pytorch/pull/125510
self.prof.profiler = None
- self.prof.start()

return self.current_time

71 changes: 71 additions & 0 deletions csrc/bfs.h
@@ -549,6 +549,77 @@ class BFS {
Direction allowed_direction_ = Direction::Undefined;
};

// Unlike the default BFS behavior, an Expr is considered ready to
// visit as soon as any of its inputs or outputs has its dependency met
template <
typename ExprT,
typename ValT,
typename DefinitionT,
typename UsesT,
typename InputsT,
typename OutputsT>
class BFSWithPermissiveDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;

BFSWithPermissiveDependence(
DefinitionT definition,
UsesT uses,
InputsT inputs,
OutputsT outputs,
std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
definition,
uses,
inputs,
outputs,
std::move(from),
std::move(to),
require_all_to_visited,
allowed_direction) {}

std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
const ExprT& expr) const override {
// Either any inputs or any outputs must have been visited
decltype(auto) inputs = this->inputs_(expr);
if (!inputs.empty() && this->allowed_direction_ != Direction::Backward &&
std::any_of(
inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
return this->isDependencySatisfied(input);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
inputs.begin(),
inputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& input) -> bool { return this->isVisited(input); });
return std::make_pair(Direction::Forward, prev_nodes);
}

decltype(auto) outputs = this->outputs_(expr);
if (!outputs.empty() && this->allowed_direction_ != Direction::Forward &&
std::any_of(
outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
return this->isDependencySatisfied(output);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
outputs.begin(),
outputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& output) -> bool { return this->isVisited(output); });
return std::make_pair(Direction::Backward, prev_nodes);
}
return std::nullopt;
}
};

// Find the shortest path from the from vals to the to
// vals. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found unless
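The practical difference from the base class is the readiness predicate: default BFS visits an Expr only once all of its inputs (or all of its outputs) have been visited, while the permissive variant fires as soon as any one of them has. A minimal standalone sketch of the two predicates (toy types, not nvFuser's templates):

// Standalone sketch: contrast the default "all dependencies met" rule
// with the permissive "any dependency met" rule introduced above.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Expr {
  std::string name;
  std::vector<std::string> inputs;
};

bool readyStrict(const Expr& e, const std::unordered_set<std::string>& visited) {
  // Default BFS: every input must already be visited.
  return std::all_of(e.inputs.begin(), e.inputs.end(), [&](const std::string& v) {
    return visited.count(v) > 0;
  });
}

bool readyPermissive(const Expr& e, const std::unordered_set<std::string>& visited) {
  // Permissive BFS: one visited input is enough.
  return std::any_of(e.inputs.begin(), e.inputs.end(), [&](const std::string& v) {
    return visited.count(v) > 0;
  });
}

int main() {
  Expr add{"add", {"a", "b"}};
  std::unordered_set<std::string> visited{"a"}; // only "a" has been reached
  std::cout << readyStrict(add, visited) << "\n";     // 0: "b" is still missing
  std::cout << readyPermissive(add, visited) << "\n"; // 1: "a" alone suffices
}

Under the permissive rule, a traversal can make progress through expressions that are only partially reachable from the starting vals, which is what the isReady override above implements for both directions.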
29 changes: 19 additions & 10 deletions csrc/codegen.cpp
@@ -3026,17 +3026,22 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
  } else {
    step_code << gen_index << " += " << gen_step;
  }
- if (loop->isUnrolled()) {
-   indent() << "#pragma unroll\n";
- } else if (
-     loop->circularBufferLoopStage() == CircularBufferLoopStage::Epilog) {
-   indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth() - 1
-            << "\n";
- } else if (
-     loop->circularBufferLoopStage() !=
+ if (loop->circularBufferLoopStage() !=
      CircularBufferLoopStage::NotApplicable) {
-   indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth()
-            << "\n";
+   // NOTE: requireUnroll is sometimes called on circular-buffered matmul
+   // loops when static shapes are used. To avoid hinting that the compiler
+   // should maximally unroll such loops, leading to very long compiles, we
+   // handle that case explicitly here and ignore loop->isUnrolled().
+   //
+   // Unroll "prefetch" many circular-buffered loops regardless of buffer
+   // stage (prologue, main, or epilogue).
+   int64_t prefetch = kernel_->summary()
+                          .circular_buffer_info
+                          .getCircularBufferOptionsFor(loop->iter_domain())
+                          .prefetch;
+   indent() << "#pragma unroll " << prefetch << "\n";
+ } else if (loop->isUnrolled()) {
+   indent() << "#pragma unroll\n";
  } else {
    indent() << "#pragma unroll 1\n";
  }
@@ -3505,6 +3510,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n";
}

void handle(const kir::Return* ret) final {
indent() << "return;\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
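The effect of the reordered branches is easiest to see in the emitted loop headers. A hedged sketch of the three shapes the generator can now produce (illustrative only, not verbatim nvFuser output; assumes a prefetch distance of 2):

// Illustrative CUDA C++ sketch of the emitted loop headers.
__device__ void emittedLoopShapes(float* buf, int extent) {
  // Circular-buffered loop: unrolled by the prefetch distance in every
  // stage (prologue, main, epilogue), even if requireUnroll was called.
#pragma unroll 2
  for (int i = 0; i < extent; ++i) {
    buf[i] += 1.0f;
  }
  // Unrolled, non-circular-buffered loop: maximal unrolling hint.
#pragma unroll
  for (int j = 0; j < 8; ++j) {
    buf[j] *= 2.0f;
  }
  // Everything else: unrolling explicitly inhibited.
#pragma unroll 1
  for (int k = 0; k < extent; ++k) {
    buf[k] -= 1.0f;
  }
}

Pinning the unroll factor to the prefetch distance keeps the circular-buffer rotation free of index arithmetic without inviting the compile-time blowup that a bare #pragma unroll can cause on statically shaped matmul loops.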
6 changes: 4 additions & 2 deletions csrc/device_lower/analysis/circular_buffer.cpp
@@ -232,9 +232,11 @@ IterDomain* CircularBufferInfo::getCircularBufferAxis(

const CircularBufferOptions& CircularBufferInfo::getCircularBufferOptionsFor(
IterDomain* circular_buffer_axis) const {
- auto concrete_id = lower_utils::getConcreteLoopID(circular_buffer_axis);
+ if (GpuLower::hasCurrent()) {
+   circular_buffer_axis = lower_utils::getConcreteLoopID(circular_buffer_axis);
+ }

- auto maybe_depth_it = circular_buffer_options_.find(concrete_id);
+ auto maybe_depth_it = circular_buffer_options_.find(circular_buffer_axis);

NVF_ERROR(
maybe_depth_it != circular_buffer_options_.end(),
15 changes: 15 additions & 0 deletions csrc/device_lower/analysis/device_version.cpp
@@ -5,6 +5,8 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <cuda.h>

#include <device_lower/analysis/device_version.h>
#include <device_lower/lower2device.h>
#include <mma_type.h>
@@ -19,9 +21,22 @@ void MinimumDeviceVersion::dispatch(Val* val) {
}
if (val->dtype() == DataType::Float8_e4m3fn ||
val->dtype() == DataType::Float8_e5m2) {
// See release note
// https://docs.nvidia.com/cuda/archive/12.1.0/parallel-thread-execution/index.html#ptx-isa-version-8-1
#if (CUDA_VERSION >= 12010)
ensureVersion(
{8, 9},
"Fusion contains Float8_xxx values which was introduced in Ada (8.9)");
// See release note
// https://docs.nvidia.com/cuda/archive/11.8.0/parallel-thread-execution/index.html#ptx-isa-version-7-8
#elif (CUDA_VERSION >= 11080)
ensureVersion(
{9, 0},
"Fusion contains Float8_xxx values which was introduced in Hopper (9.0)");
#else
NVF_ERROR(
    false,
    "Fusion contains Float8_xxx values which are not supported in the given CUDA version");
#endif // (CUDA_VERSION >= 12010)
}
IterVisitor::dispatch(val);
}
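Note the inversion: a newer toolkit lowers the architecture requirement, because PTX ISA 8.1 (CUDA 12.1) added fp8 support on Ada, while PTX ISA 7.8 (CUDA 11.8) only had it on Hopper. A minimal sketch of the same decision table, assuming the version constants above:

// Standalone C++17 sketch mirroring the #if chain above.
#include <optional>
#include <utility>

// Minimum (major, minor) compute capability needed for fp8 given the CUDA
// toolkit version, or nullopt if fp8 cannot be compiled at all.
constexpr std::optional<std::pair<int, int>> fp8MinArch(long cudaVersion) {
  if (cudaVersion >= 12010) {
    return std::make_pair(8, 9); // PTX ISA 8.1: fp8 available on Ada
  }
  if (cudaVersion >= 11080) {
    return std::make_pair(9, 0); // PTX ISA 7.8: fp8 on Hopper only
  }
  return std::nullopt; // toolkit too old for fp8
}

static_assert(fp8MinArch(12010).value() == std::make_pair(8, 9));
static_assert(fp8MinArch(11080).value() == std::make_pair(9, 0));
static_assert(!fp8MinArch(11040).has_value());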
18 changes: 14 additions & 4 deletions csrc/device_lower/pass/index.cpp
@@ -1686,8 +1686,8 @@ Val* hardCodedIndexGenerationForStMatrix(
Val* out_index = nullptr;

NVF_ERROR(
-     ldst->out()->dtype() == DataType::Half,
-     "we only support half type in stmatrix");
+     dataTypeSize(ldst->out()->dtype()) == 2,
+     "we only support 16-bit types in stmatrix");

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
@@ -1959,8 +1959,8 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
"size not currently supported for stmatrix");

NVF_ERROR(
-     ldst->out()->dtype() == DataType::Half,
-     "we only support half type in stmatrix");
+     dataTypeSize(ldst->out()->dtype()) == 2,
+     "we only support 16-bit types in stmatrix");

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
@@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
}

void IndexLowering::handle(const kir::SetMaxNReg* maxnreg) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::SetMaxNReg*>(maxnreg)); // NOLINT
}

void IndexLowering::handle(const kir::Return* ret) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::Return*>(ret)); // NOLINT
}

void IndexLowering::handle(const kir::AsyncCommit* commit) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
@@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const kir::GridSync*) final;
void handle(const kir::FenceAsyncProxy*) final;
void handle(const kir::WgMmaFence*) final;
void handle(const kir::SetMaxNReg*) final;
void handle(const kir::Return*) final;
void handle(const kir::MBarrierInit*) final;
void handle(const kir::MBarrierInvalidate*) final;
void handle(const kir::MBarrierArrive*) final;
13 changes: 13 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}

void handle(kir::SetMaxNReg* maxnreg) final {
std::string ptx = (maxnreg->increaseRegisters())
? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
registerReplace(
maxnreg,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{maxnreg->numberOfRegisters()},
kir::Asm::Options{/*volatile=*/true}));
}
};

std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {
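For reference, the replacement produces a volatile inline-asm node. A sketch of the equivalent hand-written CUDA (the register counts below are assumptions for illustration — the real value comes from maxnreg->numberOfRegisters(), and setmaxnreg requires sm_90a):

// Sketch of the lowered form (assumed values; requires sm_90a).
// setmaxnreg raises (inc) or lowers (dec) the per-thread register
// ceiling for all threads in the calling warpgroup; the count must be
// an immediate in [24, 256] and a multiple of 8.
__device__ void raiseRegisterCeiling() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 240;\n");
}

__device__ void lowerRegisterCeiling() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 24;\n");
}

In warp-specialized kernels this lets producer warps that only issue TMA copies give up registers so that consumer warps running MMA instructions can claim more.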
3 changes: 3 additions & 0 deletions csrc/dispatch.h
@@ -120,6 +120,8 @@ class Val;
f(GridSync); \
f(FenceAsyncProxy); \
f(WgMmaFence); \
f(SetMaxNReg); \
f(Return); \
f(MBarrierInit); \
f(MBarrierInvalidate); \
f(MBarrierArrive); \
@@ -146,6 +148,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
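Each name added to these lists flows through the usual X-macro pattern: every consumer of the macro defines f once and gets one expansion per registered type, which is why SetMaxNReg, Return, and GetCurrentStream only need to be listed here to pick up dispatch support everywhere. A minimal standalone sketch of the mechanism (toy subset; the real macro covers every IR node):

// Standalone sketch of the X-macro dispatch pattern.
#include <iostream>

#define FOR_EACH_HANDLER(f) \
  f(SetMaxNReg)             \
  f(Return)

// Expands to one type definition per listed name.
#define DECLARE_TYPE(T) struct T {};
FOR_EACH_HANDLER(DECLARE_TYPE)
#undef DECLARE_TYPE

// Expands to one handler overload per listed name.
#define DECLARE_HANDLE(T) \
  void handle(const T&) { std::cout << "handled " #T "\n"; }
FOR_EACH_HANDLER(DECLARE_HANDLE)
#undef DECLARE_HANDLE

int main() {
  handle(SetMaxNReg{}); // prints "handled SetMaxNReg"
  handle(Return{});     // prints "handled Return"
}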
(Remaining changed files not shown.)
