@@ -40,6 +40,7 @@ def is_interpreter():
40
40
("device_print_negative" , "int32" ),
41
41
("device_print_uint" , "uint32" ),
42
42
("device_print_2d_tensor" , "int32" ),
43
+ ("device_print" , "bool" ),
43
44
])
44
45
def test_print (func_type : str , data_type : str , device : str ):
45
46
if device == "xpu" and data_type == "float64" and not tr .driver .active .get_current_target ().arch ['has_fp64' ]:
@@ -66,7 +67,7 @@ def test_print(func_type: str, data_type: str, device: str):
66
67
# Format is
67
68
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
68
69
expected_lines = Counter ()
69
- if func_type in ("print" , "device_print" , "device_print_uint" ):
70
+ if func_type in ("print" , "device_print" , "device_print_uint" ) and data_type != "bool" :
70
71
for i in range (N ):
71
72
offset = (1 << 31 ) if data_type == "uint32" else 0
72
73
line = f"pid (0, 0, 0) idx ({ i :3} ) x: { i + offset } "
@@ -115,6 +116,10 @@ def test_print(func_type: str, data_type: str, device: str):
115
116
for x in range (x_dim ):
116
117
for y in range (y_dim ):
117
118
expected_lines [f"pid (0, 0, 0) idx ({ x } , { y :2} ): { (x * y_dim + y )} " ] = 1
119
+ elif data_type == "bool" :
120
+ expected_lines ["pid (0, 0, 0) idx ( 0) x: 0" ] = 1
121
+ for i in range (1 , N ):
122
+ expected_lines [f"pid (0, 0, 0) idx ({ i :3} ) x: 1" ] = 1
118
123
119
124
actual_lines = Counter ()
120
125
for line in outs :
0 commit comments