Skip to content

Conversation

@clementval
Copy link
Contributor

@clementval clementval requested a review from wangzpgi October 27, 2025 20:58
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 27, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

As described in the programming guide: https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations


Full diff: https://github.com/llvm/llvm-project/pull/165316.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+2)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+60)
  • (modified) flang/module/cudadevice.f90 (+21-6)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+22)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c3cd119b96174..ed0cbd3bdf16b 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -211,6 +211,8 @@ struct IntrinsicLibrary {
   mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
+  mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genBesselJn(mlir::Type,
                                  llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genBesselYn(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 6b02fefb92196..645540ad22e51 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -50,6 +50,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genBarrierInit,
      {{{"barrier", asAddr}, {"count", asValue}}},
      /*isElemental=*/false},
+    {"barrier_try_wait",
+     &I::genBarrierTryWait,
+     {{{"barrier", asAddr}, {"token", asValue}}},
+     /*isElemental=*/false},
+    {"barrier_try_wait_sleep",
+     &I::genBarrierTryWaitSleep,
+     {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
+     /*isElemental=*/false},
     {"bessel_jn",
      &I::genBesselJn,
      {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -3280,6 +3289,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
   mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
 }
 
+// BARRIER_TRY_WAIT (CUDA)
+mlir::Value
+IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  fir::StoreOp::create(builder, loc, zero, res);
+  mlir::Value ns =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);
+  mlir::Value load = fir::LoadOp::create(builder, loc, res);
+  auto whileOp = mlir::scf::WhileOp::create(
+      builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load});
+  mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
+  mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(beforeBlock);
+  mlir::Value condition = mlir::arith::CmpIOp::create(
+      builder, loc, mlir::arith::CmpIPredicate::ne, beforeArg, zero);
+  mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg);
+  mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
+  mlir::Value afterArg = afterBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(afterBlock);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
+  mlir::Value ret =
+      mlir::NVVM::InlinePtxOp::create(
+          builder, loc, {resultType}, {barrier, args[1], ns}, {},
+          ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
+          "selp.b32 %0, 1, 0, p;",
+          {})
+          .getResult(0);
+  mlir::scf::YieldOp::create(builder, loc, ret);
+  builder.setInsertionPointAfter(whileOp);
+  return whileOp.getResult(0);
+}
+
+// BARRIER_TRY_WAIT_SLEEP (CUDA)
+mlir::Value
+IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
+                                         llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
+  return mlir::NVVM::InlinePtxOp::create(
+             builder, loc, {resultType}, {barrier, args[1], args[2]}, {},
+             ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
+             "selp.b32 %0, 1, 0, p;",
+             {})
+      .getResult(0);
+}
+
 // BESSEL_JN
 fir::ExtendedValue
 IntrinsicLibrary::genBesselJn(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 5182950cbffea..ea54c974c9e7c 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1998,6 +1998,18 @@ attributes(device,host) logical function on_device() bind(c)
 
   ! TMA Operations
 
+  interface barrier_arrive
+    attributes(device) function barrier_arrive(barrier) result(token)
+      integer(8), shared :: barrier
+      integer(8) :: token
+    end function
+    attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
+      integer(8), shared :: barrier
+      integer(4), value :: count
+      integer(8) :: token
+    end function
+  end interface
+
   interface 
     attributes(device) subroutine barrier_init(barrier, count)
       integer(8), shared :: barrier
@@ -2005,15 +2017,18 @@ attributes(device) subroutine barrier_init(barrier, count)
     end subroutine
   end interface
 
-  interface barrier_arrive
-    attributes(device) function barrier_arrive(barrier) result(token)
+  interface
+    attributes(device) integer function barrier_try_wait(barrier, token)
       integer(8), shared :: barrier
-      integer(8) :: token
+      integer(8), value  :: token
     end function
-    attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
+  end interface
+  
+  interface
+    attributes(device) integer function barrier_try_wait_sleep(barrier, token, ns)
       integer(8), shared :: barrier
-      integer(4), value :: count
-      integer(8) :: token
+      integer(8), value  :: token
+      integer(4), value  :: ns
     end function
   end interface
 
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 7d6caf58d71b3..27bca167be2ed 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -479,3 +479,25 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_bulk_s2g
 ! CHECL: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+
+attributes(global) subroutine test_barrier_try_wait()
+  integer :: istat
+  integer(8), shared :: barrier1
+  integer(8) :: token
+  istat = barrier_try_wait(barrier1, token)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_barrier_try_wait()
+! CHECK: scf.while
+! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %{{.*}}, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %c1000000{{.*}} : !llvm.ptr, i64, i32) -> i32
+
+attributes(global) subroutine test_barrier_try_wait_sleep()
+  integer :: istat
+  integer(8), shared :: barrier1
+  integer(8) :: token
+  integer(4) :: sleep_time
+  istat = barrier_try_wait_sleep(barrier1, token, sleep_time)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
+! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32

@clementval clementval enabled auto-merge (squash) October 28, 2025 17:19
@clementval clementval merged commit 56c1d35 into llvm:main Oct 28, 2025
8 of 9 checks passed
@clementval clementval deleted the cuf_tma_try_wait branch October 28, 2025 17:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants