diff --git a/modules/compiler/test/lit/passes/work-item-loop-array.ll b/modules/compiler/test/lit/passes/work-item-loop-array.ll new file mode 100644 index 000000000..0f49e8892 --- /dev/null +++ b/modules/compiler/test/lit/passes/work-item-loop-array.ll @@ -0,0 +1,29 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +; RUN: muxc --passes work-item-loops,verify < %s | FileCheck %s +target triple = "spir64-unknown-unknown" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" + + +; CHECK: @llvm.used = appending global [1 x ptr] [ptr @foo.mux-barrier-wrapper], section "llvm.metadata" +@llvm.used = appending global [1 x ptr] [ptr @foo], section "llvm.metadata" + + +define void @foo() #0 { + ret void +} +attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" } diff --git a/modules/compiler/utils/include/compiler/utils/pass_functions.h b/modules/compiler/utils/include/compiler/utils/pass_functions.h index 048cf1a0e..d0252ec5e 100644 --- a/modules/compiler/utils/include/compiler/utils/pass_functions.h +++ b/modules/compiler/utils/include/compiler/utils/pass_functions.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -75,6 +76,18 @@ void replaceConstantExpressionWithInstruction(llvm::Constant *const constant); void remapConstantExpr(llvm::ConstantExpr *expr, llvm::Constant *from, llvm::Constant *to); +/// @brief remap operands of a constant array +/// +/// @note This will create a new constant array and replace references to +/// the original constant with the new one +/// +/// @param[in] arr Constant array to be remapped +/// @param[in] from Constant which if found in array will be +/// replaced +/// @param[in] to Constant which will replace any operands which are `from` +void remapConstantArray(llvm::ConstantArray *arr, llvm::Constant *from, + llvm::Constant *to); + /// @brief Discover if input function references debug info metadata nodes /// /// @param[in] func Function to check diff --git a/modules/compiler/utils/source/pass_functions.cpp b/modules/compiler/utils/source/pass_functions.cpp index 0d93ce43a..c9d52c3f0 100644 --- a/modules/compiler/utils/source/pass_functions.cpp +++ b/modules/compiler/utils/source/pass_functions.cpp @@ -73,20 +73,35 @@ uint64_t computeApproximatePrivateMemoryUsage(const llvm::Function &fn) { return bytes; } -void remapConstantExpr(llvm::ConstantExpr *expr, llvm::Constant *from, - llvm::Constant *to) { - llvm::SmallVector newOps; - // iterate through the constant expression and create a vector of old and new +static llvm::SmallVector getNewOps(llvm::Constant *constant, + llvm::Constant *from, + llvm::Constant *to) { + llvm::SmallVector newOps; + // iterate through the constant and create a vector of old and new // ones - for (unsigned i = 0, e = expr->getNumOperands(); i != e; ++i) { - auto op = expr->getOperand(i); + for (unsigned i = 0, e = constant->getNumOperands(); i != e; ++i) { + auto op = constant->getOperand(i); if (op == from) { newOps.push_back(to); } else { - newOps.push_back(op); + newOps.push_back(llvm::cast(op)); } } + return newOps; +} +void remapConstantArray(llvm::ConstantArray *arr, llvm::Constant *from, + llvm::Constant *to) { + llvm::SmallVector newOps = getNewOps(arr, from, to); + // Create a new array with the list of operands and replace all uses with + llvm::Constant *newConstant = + llvm::ConstantArray::get(arr->getType(), newOps); + arr->replaceAllUsesWith(newConstant); +} + +void remapConstantExpr(llvm::ConstantExpr *expr, llvm::Constant *from, + llvm::Constant *to) { + llvm::SmallVector newOps = getNewOps(expr, from, to); // Create a new expression with the list of operands and replace all uses with llvm::Constant *newConstant = expr->getWithOperands(newOps); expr->replaceAllUsesWith(newConstant); diff --git a/modules/compiler/utils/source/work_item_loops_pass.cpp b/modules/compiler/utils/source/work_item_loops_pass.cpp index af8e5fcdf..766275949 100644 --- a/modules/compiler/utils/source/work_item_loops_pass.cpp +++ b/modules/compiler/utils/source/work_item_loops_pass.cpp @@ -1716,10 +1716,12 @@ Function *compiler::utils::WorkItemLoopsPass::makeWrapperFunction( for (auto *user : refF.users()) { if (ConstantExpr *constant = dyn_cast(user)) { remapConstantExpr(constant, &refF, new_wrapper); + } else if (ConstantArray *ca = dyn_cast(user)) { + remapConstantArray(ca, &refF, new_wrapper); } else if (!isa(user)) { llvm_unreachable( "Cannot handle user of function being anything other than a " - "ConstantExpr or CallInst"); + "ConstantExpr, ConstantArray or CallInst"); } }