18
18
#
19
19
# ===---------------------------------------------------------------------------
20
20
21
- from typing import Dict , Tuple , List
21
+ from typing import Dict , Tuple
22
22
23
23
import mlir .ir as ir
24
24
from mlir .dialects import tosa , linalg , arith , tensor , math
25
- import copy , array , sys
25
+ import copy , array
26
26
import numpy
27
- import functools
28
27
29
28
from ..graph import *
30
29
from ..graph .graph import TensorDType
@@ -2394,7 +2393,10 @@ def convolution2d_op(
2394
2393
out_shape = node .tensor_meta ["shape" ]
2395
2394
strides_attr = ir ._denseI64ArrayAttr (strides , None )
2396
2395
dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2397
- conv2d_result = tensor .EmptyOp (out_shape , result_element_type )
2396
+ conv2d_result = tensor .EmptyOp (out_shape , result_element_type ).result
2397
+ f32 = ir .F32Type .get ()
2398
+ zero = arith .ConstantOp (value = ir .FloatAttr .get (f32 , 0.0 ), result = f32 ).result
2399
+ conv2d_result = linalg .fill (zero , outs = [conv2d_result ])
2398
2400
conv2d_nchw_op = linalg .conv_2d_nchw_fchw (
2399
2401
input_val ,
2400
2402
filter_val ,
@@ -2419,7 +2421,6 @@ def convolution2d_op(
2419
2421
def maxpool2d_op (
2420
2422
node : Conv2dOp , symbol_table : Dict [Tuple [str , int ], ir .Operation ]
2421
2423
):
2422
- # print(node.kwargs, node.args)
2423
2424
input_ = node .args [0 ]
2424
2425
kernel_size = node .args [1 ]
2425
2426
strides = node .args [2 ]
@@ -2430,22 +2431,34 @@ def maxpool2d_op(
2430
2431
input_value = symbol_table .get ((str (input_ ), 0 ))
2431
2432
kernel_size_value = tensor .EmptyOp (kernel_size , result_element_type )
2432
2433
2433
- if len (node .args ) > 3 :
2434
- dilations = node .args [4 ]
2435
- else :
2436
- dilations = [1 , 1 ]
2437
-
2438
2434
strides_attr = ir ._denseI64ArrayAttr (strides , None )
2439
- dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2440
2435
2441
2436
result = tensor .EmptyOp (result_shape , result_element_type )
2442
- op = linalg .pooling_nchw_max (
2443
- input_value ,
2444
- kernel_size_value ,
2445
- outs = [result ],
2446
- strides = strides_attr ,
2447
- dilations = dilations_attr ,
2437
+ f32 = ir .F32Type .get ()
2438
+
2439
+ # FIXME: fix this magic value!
2440
+ largest = arith .ConstantOp (
2441
+ value = ir .FloatAttr .get (f32 , numpy .finfo (numpy .float32 ).min ), result = f32
2448
2442
)
2443
+ result = linalg .fill (largest , outs = [result ])
2444
+
2445
+ if len (node .args ) > 3 :
2446
+ dilations = node .args [3 ]
2447
+ dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2448
+ op = linalg .pooling_nchw_max (
2449
+ input_value ,
2450
+ kernel_size_value ,
2451
+ outs = [result ],
2452
+ strides = strides_attr ,
2453
+ dilations = dilations_attr ,
2454
+ )
2455
+ else :
2456
+ op = linalg .pooling_nchw_max (
2457
+ input_value ,
2458
+ kernel_size_value ,
2459
+ outs = [result ],
2460
+ strides = strides_attr ,
2461
+ )
2449
2462
2450
2463
return op
2451
2464
0 commit comments