Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 10d0c61

Browse files
Nicolas Vasilachetensorflower-gardener
Nicolas Vasilache
authored andcommitted
Add a generic Linalg op
This CL introduces a linalg.generic op to represent generic tensor contraction operations on views. A linalg.generic operation requires a numbers of attributes that are sufficient to emit the computation in scalar form as well as compute the appropriate subviews to enable tiling and fusion. These attributes are very similar to the attributes for existing operations such as linalg.matmul etc and existing operations can be implemented with the generic form. In the future, most existing operations can be implemented using the generic form. This CL starts by splitting out most of the functionality of the linalg::NInputsAndOutputs trait into a ViewTrait that queries the per-instance properties of the op. This allows using the attribute informations. This exposes an ordering of verifiers issue where ViewTrait::verify uses attributes but the verifiers for those attributes have not been run. The desired behavior would be for the verifiers of the attributes specified in the builder to execute first but it is not the case atm. As a consequence, to emit proper error messages and avoid crashing, some of the linalg.generic methods are defensive as such: ``` unsigned getNumInputs() { // This is redundant with the `n_views` attribute verifier but ordering of verifiers // may exhibit cases where we crash instead of emitting an error message. if (!getAttr("n_views") || n_views().getValue().size() != 2) return 0; ``` In pretty-printed form, the specific attributes required for linalg.generic are factored out in an independent dictionary named "_". When parsing its content is flattened and the "_name" is dropped. This allows using aliasing for reducing boilerplate at each linalg.generic invocation while benefiting from the Tablegen'd verifier form for each named attribute in the dictionary. For instance, implementing linalg.matmul in terms of linalg.generic resembles: ``` func @mac(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 return %e: f32 } #matmul_accesses = [ (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n) ] #matmul_trait = { doc = "C(m, n) += A(m, k) * B(k, n)", fun = @mac, indexing_maps = #matmul_accesses, library_call = "linalg_matmul", n_views = [2, 1], n_loop_types = [2, 1, 0] } ``` And can be used in multiple places as: ``` linalg.generic #matmul_trait %A, %B, %C [other-attributes] : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32> ``` In the future it would be great to have a mechanism to alias / register a new linalg.op as a pair of linalg.generic, #trait. Also, note that with one could theoretically only specify the `doc` string and parse all the attributes from it. PiperOrigin-RevId: 261338740
1 parent 7c2926b commit 10d0c61

16 files changed

+772
-70
lines changed

include/mlir/AffineOps/AffineOps.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Ops.td - Affine operation definitions ---------------*- tablegen -*-===//
1+
//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===//
22
//
33
// Copyright 2019 The MLIR Authors.
44
//
@@ -28,6 +28,8 @@
2828
include "mlir/IR/OpBase.td"
2929
#endif // OP_BASE
3030

31+
include "mlir/AffineOps/AffineOpsBase.td"
32+
3133
def Affine_Dialect : Dialect {
3234
let name = "affine";
3335
let cppNamespace = "";
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===//
2+
//
3+
// Copyright 2019 The MLIR Authors.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
// =============================================================================
17+
//
18+
// Defines base support for MLIR affine operations.
19+
//
20+
//===----------------------------------------------------------------------===//
21+
22+
#ifdef AFFINE_OPS_BASE
23+
#else
24+
#define AFFINE_OPS_BASE
25+
26+
#ifdef OP_BASE
27+
#else
28+
include "mlir/IR/OpBase.td"
29+
#endif // OP_BASE
30+
31+
// Attributes containing affine maps.
32+
def AffineMapAttr : Attr<
33+
CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
34+
let storageType = [{ AffineMapAttr }];
35+
let returnType = [{ AffineMap }];
36+
let constBuilderCall = "$_builder.getAffineMapAttr($0)";
37+
}
38+
39+
def AffineMapArrayAttr : TypedArrayAttrBase<AffineMapAttr,
40+
"AffineMap array attribute"> {
41+
let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
42+
}
43+
44+
#endif // AFFINE_OPS_BASE

include/mlir/EDSC/Intrinsics.h

+1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ using alloc = ValueBuilder<AllocOp>;
190190
using affine_apply = ValueBuilder<AffineApplyOp>;
191191
using affine_load = ValueBuilder<AffineLoadOp>;
192192
using affine_store = OperationBuilder<AffineStoreOp>;
193+
using call = OperationBuilder<mlir::CallOp>;
193194
using constant_float = ValueBuilder<ConstantFloatOp>;
194195
using constant_index = ValueBuilder<ConstantIndexOp>;
195196
using constant_int = ValueBuilder<ConstantIntOp>;

include/mlir/IR/AffineMap.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,12 @@ AffineMap simplifyAffineMap(AffineMap map);
153153

154154
/// Returns a map of codomain to domain dimensions such that the first codomain
155155
/// dimension for a particular domain dimension is selected.
156-
/// Returns an empty map if the input map is empty.
156+
/// Returns an empty map if the input map is empty or if `map` is not invertible
157+
/// (i.e. `map` does not contain a subset that is a permutation of full domain
158+
/// rank).
157159
///
158160
/// Prerequisites:
159-
/// 1. `map` must contain a subset that is a permutation of full domain rank.
160-
/// 2. `map` has no symbols.
161+
/// 1. `map` has no symbols.
161162
///
162163
/// Example 1:
163164
///

include/mlir/IR/Builders.h

+2
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ class Builder {
130130
FloatAttr getF16FloatAttr(float value);
131131
FloatAttr getF32FloatAttr(float value);
132132
FloatAttr getF64FloatAttr(double value);
133+
133134
IntegerAttr getI32IntegerAttr(int32_t value);
134135
IntegerAttr getI64IntegerAttr(int64_t value);
135136

137+
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
136138
ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
137139
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
138140
ArrayAttr getF32ArrayAttr(ArrayRef<float> values);

include/mlir/Linalg/IR/LinalgLibraryOps.td

+178-16
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020
//
2121
//===----------------------------------------------------------------------===//
2222

23-
include "mlir/Linalg/IR/LinalgBase.td"
24-
2523
#ifdef LINALG_LIBRARY_OPS
2624
#else
2725
#define LINALG_LIBRARY_OPS
2826

27+
include "mlir/AffineOps/AffineOpsBase.td"
28+
include "mlir/Linalg/IR/LinalgBase.td"
29+
2930
class LinalgParametricNativeOpTrait<string prop, string parameters> :
3031
NativeOpTrait<"linalg::" # prop # parameters>
3132
{}
@@ -65,29 +66,28 @@ class ViewRanks<list<int> ranks> :
6566
LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
6667
{}
6768

69+
def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
70+
6871
// Base Tablegen class for Linalg ops.
6972
// Linalg ops that correspond to library calls operate on linalg::View as their
7073
// first operands. These may be optionally followed by non-view operands
7174
// depending on the specific Linalg op.
72-
class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
73-
: Op<Linalg_Dialect, mnemonic, props> {
75+
class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
76+
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [ViewTraits])> {
7477
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
7578
let printer = [{ printLinalgLibraryOp(p, *this); }];
79+
}
7680

77-
let extraClassDeclaration = [{
78-
static StringRef getLibraryCallName() {
81+
class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
82+
: LinalgLibraryBase_Op<mnemonic, props> {
83+
84+
code classDeclaration = [{
85+
StringRef getLibraryCallName() {
7986
return "linalg_}] # mnemonic # [{";
8087
}
8188
}];
8289
}
8390

84-
def AffineMapAttr : Attr<
85-
CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
86-
let storageType = [{ AffineMapAttr }];
87-
let returnType = [{ AffineMap }];
88-
let constBuilderCall = "$_builder.getAffineMapAttr($0)";
89-
}
90-
9191
////////////////////////////////////////////////////////////////////////////////
9292
// Concrete Linalg ops.
9393
////////////////////////////////////////////////////////////////////////////////
@@ -138,7 +138,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
138138
return build(
139139
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
140140
}]>];
141-
let extraClassDeclaration = [{
141+
let extraClassDeclaration = classDeclaration # [{
142142
unsigned getNumParallelLoops() {
143143
auto *view = *(getOperands().begin());
144144
return view->getType().cast<ViewType>().getRank();
@@ -151,7 +151,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
151151

152152
def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
153153
let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
154-
let extraClassDeclaration = [{
154+
let extraClassDeclaration = classDeclaration # [{
155155
unsigned getNumParallelLoops() {
156156
auto *view = *(getOperands().begin());
157157
return view->getType().cast<ViewType>().getRank();
@@ -170,20 +170,23 @@ def DotOp : LinalgLibrary_Op<"dot",
170170
NLoopTypes<0, 1, 0>,
171171
ViewRanks<[1, 1, 0]>]> {
172172
let arguments = (ins View, View, View);
173+
let extraClassDeclaration = classDeclaration;
173174
}
174175

175176
def MatvecOp : LinalgLibrary_Op<"matvec",
176177
[NInputsAndOutputs<2, 1>,
177178
NLoopTypes<1, 1, 0>,
178179
ViewRanks<[2, 1, 1]>]> {
179180
let arguments = (ins View, View, View);
181+
let extraClassDeclaration = classDeclaration;
180182
}
181183

182184
def MatmulOp : LinalgLibrary_Op<"matmul",
183185
[NInputsAndOutputs<2, 1>,
184186
NLoopTypes<2, 1, 0>,
185187
ViewRanks<[2, 2, 2]>]> {
186188
let arguments = (ins View, View, View);
189+
let extraClassDeclaration = classDeclaration;
187190
}
188191

189192
def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
@@ -208,7 +211,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
208211
let arguments = (ins View:$filter, View:$input, View:$output,
209212
OptionalAttr<I64ArrayAttr>:$strides,
210213
OptionalAttr<I64ArrayAttr>:$dilations);
211-
let extraClassDeclaration = [{
214+
let extraClassDeclaration = classDeclaration # [{
212215
// TODO(ntv) extend to support more than 1 dimensions and potentially
213216
// grouping too.
214217
unsigned getNumBatchDimensions() { return 1; }
@@ -248,4 +251,163 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
248251
let verifier = [{ return ::verify(*this); }];
249252
}
250253

254+
def GenericOp : LinalgLibraryBase_Op<"generic", []> {
255+
let description = [{
256+
Generic Linalg op form where the key properties of the computation are
257+
specified as attributes. In pretty form, a linalg.generic op is written as:
258+
259+
```
260+
linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
261+
!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
262+
```
263+
264+
Where #trait_attributes is an alias of a dictionary attribute containing:
265+
- doc [optional]: a documentation string
266+
- fun: a SymbolRefAttr that must resolve to an existing function symbol.
267+
To support inplace updates in a generic fashion, the signature of the
268+
function must be:
269+
```
270+
fun([input views element types], [output views element types])
271+
-> ([output views element types])
272+
```
273+
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
274+
and output view. Such AffineMapAttr specifies the mapping between the
275+
loops and the indexing within each view.
276+
- library_call [optional]: a StringAttr containing the name of an
277+
external library function that the linalg.generic operation maps to.
278+
The external library is assumed to be dynamically linked and no strong
279+
compile-time guarantees are provided. In the absence of such a library
280+
call, linalg.generic will always lower to loops.
281+
- n_loops: a triple of I64Attr representing the number of enclosing
282+
[parallel, reduction, window] loops respectively.
283+
- n_views: a pair of I64Attr representing the number of input (readonly)
284+
and output (readwrite) views.
285+
286+
Example:
287+
Defining a #matmul_trait attribute in MLIR can be done as follows:
288+
```
289+
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
290+
%d = mulf %a, %b: f32
291+
%e = addf %c, %d: f32
292+
return %e: f32
293+
}
294+
#matmul_accesses = [
295+
(m, n, k) -> (m, k),
296+
(m, n, k) -> (k, n),
297+
(m, n, k) -> (m, n)
298+
]
299+
#matmul_trait = {
300+
doc = "C(m, n) += A(m, k) * B(k, n)",
301+
fun = @fma,
302+
indexing_maps = #matmul_accesses,
303+
library_call = "linalg_matmul",
304+
n_views = [2, 1],
305+
n_loop_types = [2, 1, 0]
306+
}
307+
```
308+
309+
And can be reused in multiple places as:
310+
```
311+
linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
312+
!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
313+
```
314+
315+
This may lower to either:
316+
```
317+
call @linalg_matmul(%A, %B, %C) :
318+
(!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>)
319+
-> ()
320+
```
321+
322+
or IR resembling:
323+
```
324+
loop.for %m = %c0 to %M step %c1 {
325+
loop.for %n = %c0 to %N step %c1 {
326+
loop.for %k = %c0 to %K step %c1 {
327+
%a = linalg.load %A[%m, %k] : !linalg.view<?x?xf32>
328+
%b = linalg.load %B[%k, %n] : !linalg.view<?x?xf32>
329+
%c = linalg.load %C[%m, %n] : !linalg.view<?x?xf32>
330+
%d = call @mac(%a, %b, %c) : (f32, f32, f32) -> (f32)
331+
linalg.store %d, %C[%m, %n] : !linalg.view<?x?x?xf32>
332+
}
333+
}
334+
}
335+
```
336+
}];
337+
let arguments = (ins Variadic<View>:$views,
338+
SymbolRefAttr:$fun,
339+
AffineMapArrayAttr:$indexing_maps,
340+
I64ArrayAttr:$n_loop_types,
341+
I64ArrayAttr:$n_views,
342+
OptionalAttr<StrAttr>:$doc,
343+
OptionalAttr<StrAttr>:$library_call);
344+
let extraClassDeclaration = [{
345+
SmallVector<StringRef, 8> linalgTraitAttrNames() {
346+
return SmallVector<StringRef, 8>{
347+
"doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
348+
};
349+
}
350+
unsigned getNumInputs() {
351+
if (!getAttr("n_views") || n_views().getValue().size() != 2)
352+
return 0;
353+
auto val = n_views().getValue()[0].cast<IntegerAttr>().getValue();
354+
assert(val.getSExtValue() >= 0);
355+
return val.getZExtValue();
356+
}
357+
unsigned getNumOutputs() {
358+
if (!getAttr("n_views") || n_views().getValue().size() != 2)
359+
return 0;
360+
auto val = n_views().getValue()[1].cast<IntegerAttr>().getValue();
361+
assert(val.getSExtValue() >= 0);
362+
return val.getZExtValue();
363+
}
364+
unsigned getNumParallelLoops() {
365+
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
366+
return 0;
367+
auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
368+
assert(val.getSExtValue() >= 0);
369+
return val.getZExtValue();
370+
}
371+
unsigned getNumReductionLoops() {
372+
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
373+
return 0;
374+
auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
375+
assert(val.getSExtValue() >= 0);
376+
return val.getZExtValue();
377+
}
378+
unsigned getNumWindowLoops() {
379+
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
380+
return 0;
381+
auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
382+
assert(val.getSExtValue() >= 0);
383+
return val.getZExtValue();
384+
}
385+
unsigned getNumLoops() {
386+
return getNumParallelLoops() + getNumReductionLoops() +
387+
getNumWindowLoops();
388+
}
389+
StringRef getFunName() {
390+
return fun();
391+
}
392+
StringRef getLibraryCallName() {
393+
return library_call().hasValue() ? library_call().getValue() : "";
394+
}
395+
AffineMap getIndexingMap(unsigned i) {
396+
assert(i < getNumInputsAndOutputs());
397+
return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
398+
}
399+
AffineMap getInputIndexingMap(unsigned i) {
400+
assert(i < getNumInputs());
401+
return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
402+
}
403+
AffineMap getOutputIndexingMap(unsigned i) {
404+
assert(i < getNumOutputs());
405+
return indexing_maps().getValue()[i + getNumInputs()]
406+
.cast<AffineMapAttr>().getValue();
407+
}
408+
}];
409+
let printer = [{ return ::print(p, *this); }];
410+
let verifier = [{ return ::verify(*this); }];
411+
let parser = [{ return ::parse$cppClass(parser, result); }];
412+
}
251413
#endif // LINALG_LIBRARY_OPS

include/mlir/Linalg/IR/LinalgOps.h

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#ifndef MLIR_LINALG_LINALGOPS_H_
1919
#define MLIR_LINALG_LINALGOPS_H_
2020

21+
#include "mlir/IR/AffineMap.h"
2122
#include "mlir/IR/Builders.h"
2223
#include "mlir/IR/OpDefinition.h"
2324
#include "mlir/Linalg/IR/LinalgTraits.h"

0 commit comments

Comments
 (0)