-
| 
         My question is just as in the title. To elaborate: in the middle of a software-pipelined mainloop, is it safe to break out early once a stopping condition is met? As an illustrative example, I attach a relevant code block below. I apologize that the code block is long, but it is necessary to provide the full context for the question. The early-stopping conditions are checked and executed around the lines:   still_going = true;
  compute(stage_idx, i, &still_going);  // GEMM, also sets the value of still_going
  if (!still_going) {
      break;
  }
  }[CODE] (long)    using PipelineTmaAsync = cutlass::PipelineTmaAsync<NUM_STAGES>;
    using PipelineState = cutlass::PipelineState<NUM_STAGES>;
    using BarrierType = typename PipelineTmaAsync::ProducerBarrierType;
    static constexpr auto num_consumers = cute::thr_size(TiledMma{});
    auto pipeline_params = typename PipelineTmaAsync::Params{};
    pipeline_params.transaction_bytes = tma_size_bytes;
    pipeline_params.role = PipelineTmaAsync::ThreadCategory::ProducerConsumer;
    pipeline_params.is_leader = threadIdx.x == 0;
    pipeline_params.num_consumers = num_consumers;
    auto pipeline = PipelineTmaAsync{shared_storage.pipeline, pipeline_params, ClusterShape{}};
    auto smem_pipe_read = PipelineState{};
    auto smem_pipe_write = cutlass::make_producer_start_state<PipelineTmaAsync>();
    const auto num_blocks_tma_prologue = cute::min(num_blocks, NUM_STAGES);
    const auto num_blocks_mma_prologue = cute::min(1, num_blocks_tma_prologue);
    const auto num_blocks_mma_mainloop = num_blocks - num_blocks_mma_prologue;
    /********************************************************************
     * `still_going` tracks whether an early-stopping condition is met. *
     ********************************************************************/
    bool still_going = false;
    int block_idx = 0;
    // TMA Prologue
    CUTE_UNROLL
    for (int i = 0; i < num_blocks_tma_prologue; ++i) {
        pipeline.producer_acquire(smem_pipe_write);
        auto stage_idx = smem_pipe_write.index();
        auto tma_mbar = pipeline.producer_get_barrier(smem_pipe_write);
        fetch_data(tma_mbar, i, stage_idx);  // involves a TMA load
        pipeline.producer_commit(smem_pipe_write, tma_size_bytes);
        ++smem_pipe_write;
    }
    block_idx += num_blocks_tma_prologue;
    // MMA Prologue
    CUTE_NO_UNROLL
    for (int i = 0; i < num_blocks_mma_prologue; ++i) {
        pipeline.consumer_wait(smem_pipe_read);
        auto stage_idx = smem_pipe_read.index();
        still_going = true;
        compute(stage_idx, i, &still_going);  // GEMM, also sets the value of still_going
        if (!still_going) {
            break;
        }
        pipeline.consumer_release(smem_pipe_read);
        ++smem_pipe_read;
    }
    // Main loop: MMA and TMA.
    CUTE_NO_UNROLL
    for (int i = 0; i < num_blocks_mma_mainloop; ++i) {
        pipeline.consumer_wait(smem_pipe_read);
        auto stage_idx = smem_pipe_read.index();
        still_going = false;
        compute(stage_idx, i, &still_going);  // GEMM, also sets the value of still_going
        if (!still_going) {
            break;
        }
        // next read stage
        if (block_idx < num_blocks) {
            pipeline.producer_acquire(smem_pipe_write);
            auto stage_idx = smem_pipe_write.index();
            auto tma_mbar = pipeline.producer_get_barrier(smem_pipe_write);
            fetch_data(tma_mbar, block_idx, stage_idx);  // involves a TMA load 
            pipeline.producer_commit(smem_pipe_write, tma_size_bytes);
            ++smem_pipe_write;
            ++block_idx;
        }
        pipeline.consumer_release(smem_pipe_read);
        ++smem_pipe_read;
    }
    // Wait on all GMMAs
    cute::warpgroup_wait<0>();
    cute::warpgroup_fence_operand(rO);
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_sync();
    } else {
        __syncthreads();
    }Truly appreciate your help!  | 
  
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 6 replies
-
| 
         Yes this is fine in general. Ideally we would model this as some kind of while loop around an updating k tile counter etc. You just have to be really careful to make sure the pipeline states for producers and consumers agree if you terminate early in case this is a persistent kernel or you are fusing with another collective later in the lifetime of kernel. 
  | 
  
Beta Was this translation helpful? Give feedback.
Yes this is fine in general. Ideally we would model this as some kind of while loop around an updating k tile counter etc. You just have to be really careful to make sure the pipeline states for producers and consumers agree if you terminate early in case this is a persistent kernel or you are fusing with another collective later in the lifetime of kernel.
`still_going` for each other. Is that happening inside `compute()`?