Skip to content

Commit 976bcdd

Browse files
authored
[Workaround]: cast i1 to i8 for printf parameters (#3628)
It seems the `i1` type is not supported as a `printf` parameter; passing it triggers a segmentation fault in IGC. As a temporary workaround, zero-extend `i1` operands to `i8`.
1 parent 980132b commit 976bcdd

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

python/test/unit/language/test_subprocess.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def is_interpreter():
4040
("device_print_negative", "int32"),
4141
("device_print_uint", "uint32"),
4242
("device_print_2d_tensor", "int32"),
43+
("device_print", "bool"),
4344
])
4445
def test_print(func_type: str, data_type: str, device: str):
4546
if device == "xpu" and data_type == "float64" and not tr.driver.active.get_current_target().arch['has_fp64']:
@@ -66,7 +67,7 @@ def test_print(func_type: str, data_type: str, device: str):
6667
# Format is
6768
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
6869
expected_lines = Counter()
69-
if func_type in ("print", "device_print", "device_print_uint"):
70+
if func_type in ("print", "device_print", "device_print_uint") and data_type != "bool":
7071
for i in range(N):
7172
offset = (1 << 31) if data_type == "uint32" else 0
7273
line = f"pid (0, 0, 0) idx ({i:3}) x: {i + offset}"
@@ -115,6 +116,10 @@ def test_print(func_type: str, data_type: str, device: str):
115116
for x in range(x_dim):
116117
for y in range(y_dim):
117118
expected_lines[f"pid (0, 0, 0) idx ({x}, {y:2}): {(x * y_dim + y)}"] = 1
119+
elif data_type == "bool":
120+
expected_lines["pid (0, 0, 0) idx ( 0) x: 0"] = 1
121+
for i in range(1, N):
122+
expected_lines[f"pid (0, 0, 0) idx ({i:3}) x: 1"] = 1
118123

119124
actual_lines = Counter()
120125
for line in outs:

third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,15 @@ struct PrintOpConversion
160160
auto elem = elems[i];
161161

162162
os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned);
163-
printfOperands.push_back(elem);
163+
if (isa<IntegerType>(elem.getType()) &&
164+
elem.getType().getIntOrFloatBitWidth() == 1) {
165+
// FIXME: Passing an i1 operand to printf currently crashes IGC, so it is
166+
// zero-extended to i8 here. Remove this workaround once IGC fixes the problem.
167+
TritonLLVMOpBuilder b(rewriter.getUnknownLoc(), rewriter);
168+
printfOperands.push_back(b.zext(i8_ty, elem));
169+
} else {
170+
printfOperands.push_back(elem);
171+
}
164172

165173
// It's the same format string each iteration, but it's a lot easier if we
166174
// construct the format string at the same time as we populate

0 commit comments

Comments
 (0)