Skip to content

Commit 1409e1a

Browse files
authored
[SPIRV] Add logic for OpGenericCastToPtrExplicit rewriting (#146596)
This PR adds overrides in `SPIRVTTIImpl` for `collectFlatAddressOperands` and `rewriteIntrinsicWithAddressSpace` to enable `InferAddressSpacesPass` to rewrite the `llvm.spv.generic.cast.to.ptr.explicit` intrinsic (corresponding to `OpGenericCastToPtrExplicit`) when the address space of the argument can be inferred. When the destination address space of the cast matches the inferred address space of the argument, the call is replaced with that argument. When they do not match, the cast is replaced with a constant null pointer.
1 parent 763db38 commit 1409e1a

File tree

4 files changed

+150
-6
lines changed

4 files changed

+150
-6
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ add_llvm_target(SPIRVCodeGen
4444
SPIRVRegularizer.cpp
4545
SPIRVSubtarget.cpp
4646
SPIRVTargetMachine.cpp
47+
SPIRVTargetTransformInfo.cpp
4748
SPIRVUtils.cpp
4849
SPIRVEmitNonSemanticDI.cpp
4950

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- SPIRVTargetTransformInfo.cpp - SPIR-V specific TTI -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "SPIRVTargetTransformInfo.h"
10+
#include "llvm/IR/IntrinsicsSPIRV.h"
11+
12+
using namespace llvm;
13+
14+
bool llvm::SPIRVTTIImpl::collectFlatAddressOperands(
15+
SmallVectorImpl<int> &OpIndexes, Intrinsic::ID IID) const {
16+
switch (IID) {
17+
case Intrinsic::spv_generic_cast_to_ptr_explicit:
18+
OpIndexes.push_back(0);
19+
return true;
20+
default:
21+
return false;
22+
}
23+
}
24+
25+
Value *llvm::SPIRVTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
26+
Value *OldV,
27+
Value *NewV) const {
28+
auto IntrID = II->getIntrinsicID();
29+
switch (IntrID) {
30+
case Intrinsic::spv_generic_cast_to_ptr_explicit: {
31+
unsigned NewAS = NewV->getType()->getPointerAddressSpace();
32+
unsigned DstAS = II->getType()->getPointerAddressSpace();
33+
return NewAS == DstAS ? NewV
34+
: ConstantPointerNull::get(
35+
PointerType::get(NewV->getContext(), DstAS));
36+
}
37+
default:
38+
return nullptr;
39+
}
40+
}

llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,15 @@ class SPIRVTTIImpl final : public BasicTTIImplBase<SPIRVTTIImpl> {
5050
}
5151

5252
unsigned getFlatAddressSpace() const override {
53-
if (ST->isShader())
54-
return 0;
55-
// FIXME: Clang has 2 distinct address space maps. One where
53+
// Clang has 2 distinct address space maps. One where
5654
// default=4=Generic, and one with default=0=Function. This depends on the
57-
// environment. For OpenCL, we don't need to run the InferAddrSpace pass, so
58-
// we can return -1, but we might want to fix this.
59-
return -1;
55+
// environment.
56+
return ST->isShader() ? 0 : 4;
6057
}
58+
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
59+
Intrinsic::ID IID) const override;
60+
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
61+
Value *NewV) const override;
6162
};
6263

6364
} // namespace llvm
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; This test checks that the address space casts for SPIR-V generic pointer casts
2+
; are lowered correctly by the infer-address-spaces pass.
3+
; RUN: opt < %s -passes=infer-address-spaces -S --mtriple=spirv64-unknown-unknown | FileCheck %s
4+
5+
; Casting a global pointer to a global pointer.
6+
; The uses of c2 will be replaced with %global.
7+
; CHECK: @kernel1(ptr addrspace(1) %global)
8+
define i1 @kernel1(ptr addrspace(1) %global) {
9+
%c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4)
10+
%c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
11+
; CHECK: %b1 = icmp eq ptr addrspace(1) %global, null
12+
%b1 = icmp eq ptr addrspace(1) %c2, null
13+
ret i1 %b1
14+
}
15+
16+
; Casting a global pointer to a local pointer.
17+
; The uses of c2 will be replaced with null.
18+
; CHECK: @kernel2(ptr addrspace(1) %global)
19+
define i1 @kernel2(ptr addrspace(1) %global) {
20+
%c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4)
21+
%c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
22+
; CHECK: %b1 = icmp eq ptr addrspace(3) null, null
23+
%b1 = icmp eq ptr addrspace(3) %c2, null
24+
ret i1 %b1
25+
}
26+
27+
; Casting a global pointer to a private pointer.
28+
; The uses of c2 will be replaced with null.
29+
; CHECK: @kernel3(ptr addrspace(1) %global)
30+
define i1 @kernel3(ptr addrspace(1) %global) {
31+
%c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4)
32+
%c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
33+
; CHECK: %b1 = icmp eq ptr null, null
34+
%b1 = icmp eq ptr %c2, null
35+
ret i1 %b1
36+
}
37+
38+
; Casting a local pointer to a local pointer.
39+
; The uses of c2 will be replaced with %local.
40+
; CHECK: @kernel4(ptr addrspace(3) %local)
41+
define i1 @kernel4(ptr addrspace(3) %local) {
42+
%c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4)
43+
%c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
44+
; CHECK: %b1 = icmp eq ptr addrspace(3) %local, null
45+
%b1 = icmp eq ptr addrspace(3) %c2, null
46+
ret i1 %b1
47+
}
48+
49+
; Casting a local pointer to a global pointer.
50+
; The uses of c2 will be replaced with null.
51+
; CHECK: @kernel5(ptr addrspace(3) %local)
52+
define i1 @kernel5(ptr addrspace(3) %local) {
53+
%c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4)
54+
%c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
55+
; CHECK: %b1 = icmp eq ptr addrspace(1) null, null
56+
%b1 = icmp eq ptr addrspace(1) %c2, null
57+
ret i1 %b1
58+
}
59+
60+
; Casting a local pointer to a private pointer.
61+
; The uses of c2 will be replaced with null.
62+
; CHECK: @kernel6(ptr addrspace(3) %local)
63+
define i1 @kernel6(ptr addrspace(3) %local) {
64+
%c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4)
65+
%c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
66+
; CHECK: %b1 = icmp eq ptr null, null
67+
%b1 = icmp eq ptr %c2, null
68+
ret i1 %b1
69+
}
70+
71+
; Casting a private pointer to a private pointer.
72+
; The uses of c2 will be replaced with %private.
73+
; CHECK: @kernel7(ptr %private)
74+
define i1 @kernel7(ptr %private) {
75+
%c1 = addrspacecast ptr %private to ptr addrspace(4)
76+
%c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
77+
; CHECK: %b1 = icmp eq ptr %private, null
78+
%b1 = icmp eq ptr %c2, null
79+
ret i1 %b1
80+
}
81+
82+
; Casting a private pointer to a global pointer.
83+
; The uses of c2 will be replaced with null.
84+
; CHECK: @kernel8(ptr %private)
85+
define i1 @kernel8(ptr %private) {
86+
%c1 = addrspacecast ptr %private to ptr addrspace(4)
87+
%c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
88+
; CHECK: %b1 = icmp eq ptr addrspace(1) null, null
89+
%b1 = icmp eq ptr addrspace(1) %c2, null
90+
ret i1 %b1
91+
}
92+
93+
; Casting a private pointer to a local pointer.
94+
; The uses of c2 will be replaced with null.
95+
; CHECK: @kernel9(ptr %private)
96+
define i1 @kernel9(ptr %private) {
97+
%c1 = addrspacecast ptr %private to ptr addrspace(4)
98+
%c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1)
99+
; CHECK: %b1 = icmp eq ptr addrspace(3) null, null
100+
%b1 = icmp eq ptr addrspace(3) %c2, null
101+
ret i1 %b1
102+
}

0 commit comments

Comments
 (0)