class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
let arguments = (ins Variadic<AnyStridedMemRef>:$views,
AffineMapArrayAttr:$indexing_maps,
- I64ArrayAttr:$n_loop_types,
+ ArrayAttr:$iterator_types,
I64ArrayAttr:$n_views,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<FlatSymbolRefAttr>:$fun,
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
- "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
+ "doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
};
}
unsigned getNumInputs() {
return val.getZExtValue();
}
unsigned getNumParallelLoops() {
- if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+ if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
- auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
- assert(val.getSExtValue() >= 0);
- return val.getZExtValue();
+ unsigned nPar = 0;
+ for (auto ty : iterator_types()) {
+ if (ty.cast<StringAttr>().getValue() == "parallel")
+ nPar++;
+ }
+ return nPar;
}
unsigned getNumReductionLoops() {
- if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+ if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
- auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
- assert(val.getSExtValue() >= 0);
- return val.getZExtValue();
- }
- unsigned getNumWindowLoops() {
- if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+ unsigned nRed = 0;
+ for (auto ty : iterator_types()) {
+ if (ty.cast<StringAttr>().getValue() == "reduction")
+ nRed++;
+ }
+ return nRed;
+ }
+ unsigned getNumWindowLoops() {
+ if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
- auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
- assert(val.getSExtValue() >= 0);
- return val.getZExtValue();
- }
+ unsigned nWin = 0;
+ for (auto ty : iterator_types()) {
+ if (ty.cast<StringAttr>().getValue() == "window")
+ nWin++;
+ }
+ return nWin;
+ }
unsigned getNumLoops() {
return getNumParallelLoops() + getNumReductionLoops() +
getNumWindowLoops();
The external library is assumed to be dynamically linked and no strong
compile-time guarantees are provided. In the absence of such a library
call, linalg.generic will always lower to loops.
- - n_loops: a triple of I64Attr representing the number of enclosing
- [parallel, reduction, window] loops respectively.
+ - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+ the list represents and iterator of one of the following types:
+ parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
- n_loop_types = [2, 1, 0]
+ iterator_types = ["parallel", "parallel", "reduction"]
}
```
maps to. The external library is assumed to be dynamically linked and
no strong compile-time guarantees are provided. In the absence of such
a library call, linalg.indexed_generic will always lower to loops.
- - n_loops: a triple of I64Attr representing the number of enclosing
- [parallel, reduction, window] loops respectively.
+ - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+ the list represents and iterator of one of the following types:
+ parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
- n_loop_types = [2, 1, 0]
+ iterator_types = ["parallel", "parallel", "reduction"]
}
```
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
indexing_maps = [#id_2d, #id_2d, #id_2d],
- n_loop_types = [2, 0, 0],
+ iterator_types = ["parallel", "parallel"],
n_views = [2, 1]
}
func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<f32>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0, %arg0: memref<f32>, memref<f32>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<f32>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<f32>
}
fun = @foo,
indexing_maps = [ ()[N] -> (0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0: memref<i32>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0: memref<i32>
}
fun = @foo,
indexing_maps = [ () -> (1) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<i32>
}
fun = @foo,
indexing_maps = [ () -> (0, 0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
(i, j) -> (i + j)
],
n_views = [1, 1],
- n_loop_types = [2, 0, 0]
+ iterator_types = ["parallel","parallel"]
} %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
}
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0, %arg0 {
^bb1:
^bb2:
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0 {
^bb:
}: memref<f32>
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
- n_loop_types = [0, 0, 0]
+ iterator_types = []
} %arg0 {
^bb(%i: i1):
}: memref<f32>
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0 {
^bb(%f: f32):
}: memref<f32>
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f64, %f: f32):
}: memref<f32>
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0 {
^bb(%i: index, %f: i1):
}: memref<f32>
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0],
+ iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
}
// expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
linalg.indexed_generic {
n_views = [0, 1],
- n_loop_types = [1, 0, 0],
+ iterator_types = ["parallel"],
indexing_maps = [ (i) -> (i) ],
fun = @foo
} %arg0 : memref<f32>
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0],
+ iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
}
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0],
+ iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
}
linalg.indexed_generic {
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0],
+ iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<i32>
}
linalg.generic {
indexing_maps = [ (i) -> (i) ],
n_views = [0, 1],
- n_loop_types = [1, 0, 0]
+ iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f32):
%0 = constant 0: i1
]
#matmul_trait = {
n_views = [2, 1],
- n_loop_types = [2, 1, 0],
+ iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "some_external_function_name_for_vector_outerproduct_matmul"
}
]
#trait = {
n_views = [1, 2],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @foo,
library_call = "some_external_function_name_1",
#trait2 = {
n_views = [1, 2],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
}
#trait3 = {
n_views = [1, 2],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @indexed_foo,
library_call = "some_external_function_name_1",
#trait4 = {
n_views = [1, 2],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
#trait = {
indexing_maps = #accesses,
n_views = [1, 1],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
fun = @foo,
library_call = "some_external_function_name_1"
}
}
// CHECK-LABEL: func @foo
// CHECK-LABEL: func @generic
-// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_1", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
#trait2 = {
indexing_maps = #accesses,
n_views = [1, 1],
- n_loop_types = [3, 0, 0],
+ iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_2"
}
func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
return
}
// CHECK-LABEL: func @generic_region
-// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_2", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
+// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
return
}
// CHECK-LABEL: func @indexed_generic
-// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_2", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
+// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
indexing_maps = [#id_2d, #id_2d, #id_2d],
- n_loop_types = [2, 0, 0],
+ iterator_types = ["parallel", "parallel"],
n_views = [2, 1]
}
(i, j) -> (i, j)
],
n_views = [1, 1],
- n_loop_types = [2, 0, 0]
+ iterator_types = ["parallel", "parallel"]
}
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,