Skip to content

Commit eb3d301

Browse files
committed
fix: Add initialization for maxpool_2d and conv2d.
1 parent 009b9b0 commit eb3d301

File tree

4 files changed

+35
-26
lines changed

4 files changed

+35
-26
lines changed

examples/BuddyLeNet/buddy-lenet-import.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
from buddy.compiler.frontend import DynamoCompiler
2929
from buddy.compiler.graph import GraphDriver
30-
from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
31-
from buddy.compiler.ops import tosa, linalg
30+
from buddy.compiler.graph.transform import simply_fuse
31+
from buddy.compiler.ops import linalg
3232
from model import LeNet
3333

3434
# Parse command-line arguments.
@@ -37,7 +37,7 @@
3737
"--output-dir",
3838
type=str,
3939
default="./",
40-
help="Directory to save output files."
40+
help="Directory to save output files.",
4141
)
4242
args = parser.parse_args()
4343

@@ -54,9 +54,7 @@
5454

5555
# Initialize Dynamo Compiler with specific configurations as an importer.
5656
dynamo_compiler = DynamoCompiler(
57-
primary_registry=tosa.ops_registry,
58-
# primary_registry=linalg.ops_registry,
59-
# verbose=True
57+
primary_registry=linalg.ops_registry, verbose=True
6058
)
6159

6260
data = torch.randn([1, 1, 28, 28])

examples/BuddyLeNet/buddy-lenet-main.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
#include <cstdlib>
2222
#include <filesystem>
2323
#include <fstream>
24-
#include <limits>
2524
#include <string>
26-
#include <utility>
2725
#include <vector>
2826

2927
constexpr size_t ParamsSize = 44426;

frontend/Python/graph/type.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ class DeviceType(Enum):
9595

9696
CPU = "cpu"
9797
GPU = "gpu"
98-
UNKNOWN = "unknow"
98+
UNKNOWN = "unknown"

frontend/Python/ops/linalg.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
#
1919
# ===---------------------------------------------------------------------------
2020

21-
from typing import Dict, Tuple, List
21+
from typing import Dict, Tuple
2222

2323
import mlir.ir as ir
2424
from mlir.dialects import tosa, linalg, arith, tensor, math
25-
import copy, array, sys
25+
import copy, array
2626
import numpy
27-
import functools
2827

2928
from ..graph import *
3029
from ..graph.graph import TensorDType
@@ -2394,7 +2393,10 @@ def convolution2d_op(
23942393
out_shape = node.tensor_meta["shape"]
23952394
strides_attr = ir._denseI64ArrayAttr(strides, None)
23962395
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
2397-
conv2d_result = tensor.EmptyOp(out_shape, result_element_type)
2396+
conv2d_result = tensor.EmptyOp(out_shape, result_element_type).result
2397+
f32 = ir.F32Type.get()
2398+
zero = arith.ConstantOp(value=ir.FloatAttr.get(f32, 0.0), result=f32).result
2399+
conv2d_result = linalg.fill(zero, outs=[conv2d_result])
23982400
conv2d_nchw_op = linalg.conv_2d_nchw_fchw(
23992401
input_val,
24002402
filter_val,
@@ -2419,7 +2421,6 @@ def convolution2d_op(
24192421
def maxpool2d_op(
24202422
node: Conv2dOp, symbol_table: Dict[Tuple[str, int], ir.Operation]
24212423
):
2422-
# print(node.kwargs, node.args)
24232424
input_ = node.args[0]
24242425
kernel_size = node.args[1]
24252426
strides = node.args[2]
@@ -2430,22 +2431,34 @@ def maxpool2d_op(
24302431
input_value = symbol_table.get((str(input_), 0))
24312432
kernel_size_value = tensor.EmptyOp(kernel_size, result_element_type)
24322433

2433-
if len(node.args) > 3:
2434-
dilations = node.args[4]
2435-
else:
2436-
dilations = [1, 1]
2437-
24382434
strides_attr = ir._denseI64ArrayAttr(strides, None)
2439-
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
24402435

24412436
result = tensor.EmptyOp(result_shape, result_element_type)
2442-
op = linalg.pooling_nchw_max(
2443-
input_value,
2444-
kernel_size_value,
2445-
outs=[result],
2446-
strides=strides_attr,
2447-
dilations=dilations_attr,
2437+
f32 = ir.F32Type.get()
2438+
2439+
# FIXME: fix this magic value!
2440+
largest = arith.ConstantOp(
2441+
value=ir.FloatAttr.get(f32, numpy.finfo(numpy.float32).min), result=f32
24482442
)
2443+
result = linalg.fill(largest, outs=[result])
2444+
2445+
if len(node.args) > 3:
2446+
dilations = node.args[3]
2447+
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
2448+
op = linalg.pooling_nchw_max(
2449+
input_value,
2450+
kernel_size_value,
2451+
outs=[result],
2452+
strides=strides_attr,
2453+
dilations=dilations_attr,
2454+
)
2455+
else:
2456+
op = linalg.pooling_nchw_max(
2457+
input_value,
2458+
kernel_size_value,
2459+
outs=[result],
2460+
strides=strides_attr,
2461+
)
24492462

24502463
return op
24512464

0 commit comments

Comments
 (0)