32
32
F64Type ,
33
33
)
34
34
35
- # TODO: Have a way upstream to check if a floating point type .
36
- FLOAT_TYPES_ASM = {
37
- "bf16" ,
38
- "f16" ,
39
- "f32" ,
40
- "f64" ,
35
+ # TODO: Use FloatType from upstream when available .
36
+ FLOAT_BITWIDTHS = {
37
+ "bf16" : 16 ,
38
+ "f16" : 16 ,
39
+ "f32" : 32 ,
40
+ "f64" : 64 ,
41
41
# TODO: FP8 types.
42
42
}
43
43
@@ -87,28 +87,54 @@ def __init__(
87
87
88
88
class _ScalarBuilder :
89
89
def is_floating_point_type (self , t : IrType ) -> bool :
90
- return str (t ) in FLOAT_TYPES_ASM
90
+ # TODO: Use FloatType from upstream when available.
91
+ return str (t ) in FLOAT_BITWIDTHS
91
92
92
93
def is_integer_type (self , t : IrType ) -> bool :
93
94
return IntegerType .isinstance (t )
94
95
95
96
def is_index_type (self , t : IrType ) -> bool :
96
97
return IndexType .isinstance (t )
97
98
98
- def promote (self , value : Value , to_type : IrType ) -> Value :
99
- value_type = value .type
99
+ def get_typeclass (self , t : IrType , index_same_as_integer = False ) -> str :
100
+ # If this is a vector type, get the element type.
101
+ if isinstance (t , VectorType ):
102
+ t = t .element_type
103
+ if self .is_floating_point_type (t ):
104
+ return "float"
105
+ if self .is_integer_type (t ):
106
+ return "integer"
107
+ if self .is_index_type (t ):
108
+ return "integer" if index_same_as_integer else "index"
109
+ raise CodegenError (f"Unknown typeclass for type `{ t } `" )
110
+
111
+ def get_float_bitwidth (self , t : IrType ) -> int :
112
+ # If this is a vector type, get the element type.
113
+ if isinstance (t , VectorType ):
114
+ t = t .element_type
115
+ return FLOAT_BITWIDTHS [str (t )]
116
+
117
+ def to_dtype (self , value : IRProxyValue , dtype : IrType ) -> IRProxyValue :
118
+ value_type = value .ir_value .type
119
+ # Create a vector type for dtype if value is a vector.
120
+ to_type = dtype
121
+ if isinstance (value_type , VectorType ):
122
+ to_type = VectorType .get (value_type .shape , dtype )
123
+
100
124
# Short-circuit if already the right type.
101
125
if value_type == to_type :
102
126
return value
103
127
104
- attr_name = f"promote_{ value_type } _to_{ to_type } "
128
+ value_typeclass = self .get_typeclass (value_type )
129
+ to_typeclass = self .get_typeclass (dtype )
130
+ attr_name = f"to_dtype_{ value_typeclass } _to_{ to_typeclass } "
105
131
try :
106
132
handler = getattr (self , attr_name )
107
133
except AttributeError :
108
134
raise CodegenError (
109
135
f"No implemented path to implicitly promote scalar `{ value_type } ` to `{ to_type } ` (tried '{ attr_name } ')"
110
136
)
111
- return handler (value , to_type )
137
+ return IRProxyValue ( handler (value . ir_value , to_type ) )
112
138
113
139
def constant_attr (self , val : int | float , element_type : IrType ) -> Attribute :
114
140
if self .is_integer_type (element_type ) or self .is_index_type (element_type ):
@@ -153,7 +179,7 @@ def binary_arithmetic(
153
179
f"Cannot perform binary arithmetic operation '{ op } ' between { lhs_ir_type } and { rhs_ir_type } due to element type mismatch"
154
180
)
155
181
156
- typeclass = "float" if self .is_floating_point_type (lhs_ir_type ) else "integer"
182
+ typeclass = self .get_typeclass (lhs_ir_type , True )
157
183
attr_name = f"binary_{ op } _{ typeclass } "
158
184
try :
159
185
handler = getattr (self , attr_name )
@@ -176,9 +202,7 @@ def binary_vector_arithmetic(
176
202
f"Cannot perform binary arithmetic operation '{ op } ' between { lhs_ir .type } and { rhs_ir .type } due to element type mismatch"
177
203
)
178
204
179
- typeclass = (
180
- "float" if self .is_floating_point_type (lhs_element_type ) else "integer"
181
- )
205
+ typeclass = self .get_typeclass (lhs_element_type , True )
182
206
attr_name = f"binary_{ op } _{ typeclass } "
183
207
try :
184
208
handler = getattr (self , attr_name )
@@ -190,7 +214,7 @@ def binary_vector_arithmetic(
190
214
191
215
def unary_arithmetic (self , op : str , val : IRProxyValue ) -> IRProxyValue :
192
216
val_ir_type = val .ir_value .type
193
- typeclass = "float" if self .is_floating_point_type (val_ir_type ) else "integer"
217
+ typeclass = self .get_typeclass (val_ir_type , True )
194
218
attr_name = f"unary_{ op } _{ typeclass } "
195
219
try :
196
220
handler = getattr (self , attr_name )
@@ -203,9 +227,7 @@ def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
203
227
def unary_vector_arithmetic (self , op : str , val : IRProxyValue ) -> IRProxyValue :
204
228
val_ir = val .ir_value
205
229
val_element_type = VectorType (val_ir .type ).element_type
206
- typeclass = (
207
- "float" if self .is_floating_point_type (val_element_type ) else "integer"
208
- )
230
+ typeclass = self .get_typeclass (val_element_type , True )
209
231
attr_name = f"unary_{ op } _{ typeclass } "
210
232
try :
211
233
handler = getattr (self , attr_name )
@@ -217,10 +239,33 @@ def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
217
239
218
240
### Specializations
219
241
220
- def promote_index_to_f32 (self , value : Value , to_type : IrType ) -> Value :
221
- i32_type = IntegerType .get_signless (32 )
222
- i32 = arith_d .index_cast (i32_type , value )
223
- return arith_d .sitofp (to_type , i32 )
242
+ # Casting
243
+ def to_dtype_index_to_integer (self , value : Value , to_type : IrType ) -> Value :
244
+ return arith_d .index_cast (to_type , value )
245
+
246
+ def to_dtype_index_to_float (self , value : Value , to_type : IrType ) -> Value :
247
+ # Cast index to integer, and then ask for a integer to float cast.
248
+ # TODO: I don't really know how to query the machine bitwidth here,
249
+ # so using 64.
250
+ casted_to_int = arith_d .index_cast (IntegerType .get_signless (64 ), value )
251
+ return self .to_dtype (IRProxyValue (casted_to_int ), to_type ).ir_value
252
+
253
+ def to_dtype_integer_to_float (self , value : Value , to_type : IrType ) -> Value :
254
+ # sitofp
255
+ casted_to_float = arith_d .sitofp (to_type , value )
256
+ return self .to_dtype (IRProxyValue (casted_to_float ), to_type ).ir_value
257
+
258
+ def to_dtype_float_to_float (self , value : Value , to_type : IrType ) -> Value :
259
+ # Check bitwidth to determine if we need to extend or narrow
260
+ from_type = value .type
261
+ from_bitwidth = self .get_float_bitwidth (from_type )
262
+ to_bitwidth = self .get_float_bitwidth (to_type )
263
+ if from_bitwidth < to_bitwidth :
264
+ return arith_d .extf (to_type , value )
265
+ elif from_bitwidth > to_bitwidth :
266
+ return arith_d .truncf (to_type , value )
267
+ else :
268
+ raise CodegenError (f"NYI: Cast from { from_type } to { to_type } " )
224
269
225
270
# Binary integer/integer arithmetic.
226
271
def binary_add_integer (self , lhs : IRProxyValue , rhs : IRProxyValue ) -> IRProxyValue :
0 commit comments