ArrayRef<Value> vals) {
if (map.isEmpty())
return {};
- assert(map.getNumSymbols() == 0);
+
assert(map.getNumInputs() == vals.size());
SmallVector<Value, 8> res;
res.reserve(map.getNumResults());
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
- auto exprMap = AffineMap::get(dims, 0, e);
+ auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
SmallVector<Value, 4> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
res.push_back(affine_apply(exprMap, operands));
SmallVector<Value, 4> indexedValues;
indexedValues.reserve(nInputs + nOutputs);
+ auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
+ auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
+ if (attr) {
+ auto operand = linalgOp.getOperand(attr.getInt());
+ auto shapedType = operand.getType().template cast<ShapedType>();
+ allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
+ for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
+ allIvsPlusDims.push_back(b.create<DimOp>(loc, operand, idx));
+ }
+
// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
// 1.a. Emit load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
auto indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getInputIndexingMap(i), allIvs);
+ b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
// Passing through IndexedValueType emits the proper load operation.
indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
}
// 1.b. Emit load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
auto indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
+ b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
// Passing through IndexedValueType emits the proper load operation.
indexedValues.push_back(
IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
indexing.push_back(makeCanonicalAffineApplies(
- b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
+ b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
outputBuffers.push_back(linalgOp.getOutputBuffer(i));
}
inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
+ SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
+ AffineMap map = concatAffineMaps(maps);
+ if (map.getNumSymbols()) {
+ // Ignore symbols for now as they are not supported by inversePermutation.
+ unsigned dims = map.getNumDims();
+ SmallVector<AffineExpr, 8> zeros(
+ map.getNumSymbols(), getAffineConstantExpr(0, map.getContext()));
+ SmallVector<AffineExpr, 8> res;
+ for (auto result : map.getResults())
+ res.push_back(result.replaceDimsAndSymbols({}, zeros));
+
+ map = AffineMap::get(dims, 0, res, map.getContext());
+
+ // Cut off values that would have been applied to symbols
+ sizes.resize(res.size());
+ }
+
+ AffineMap invertedMap = inversePermutation(map);
if (!invertedMap)
return {};
if (invertedMap.isEmpty()) {
}
SmallVector<Value, 4> allIvs;
- auto loopRanges =
- emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
- getViewSizes(builder, linalgOp));
+ auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
+ invertedMap, sizes);
GenerateLoopNest<LoopTy>::doit(
loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
allIvs.append(ivs.begin(), ivs.end());
// CHECKLOOP-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECKLOOP-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECKLOOP-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
+// CHECKLOOP-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
// CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECKPARALLEL-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECKPARALLEL-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECKPARALLEL-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
+// CHECKPARALLEL-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+
+#conv_1d_accesses = [
+ affine_map<(m, n)[s0] -> (m + n - s0 floordiv 2)>, // in
+ affine_map<(m, n)[s0] -> (n)>, // filter
+ affine_map<(m, n)[s0] -> (m)> // out
+]
+
+#conv_1d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m) += A(m) * B(n)",
+ indexing_maps = #conv_1d_accesses,
+ library_call = "linalg_conv_1d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+ linalg.generic #conv_1d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?xf32>,
+ memref<?xf32>,
+ memref<?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv1d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: scf.for %[[b:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[m:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+// CHECKPARALLEL-LABEL: @conv1d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%{{.*}}, %{{.*}}) to (%[[dim1]], %[[dim0]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+#conv_2d_accesses = [
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2)>, // in
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m1, n1)>, // filter
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m, n)> // out
+]
+
+#conv_2d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m,n) += A(m,n) * B(m1,n1)",
+ indexing_maps = #conv_2d_accesses,
+ library_call = "linalg_conv_2d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+ linalg.generic #conv_2d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?x?xf32>,
+ memref<?x?xf32>,
+ memref<?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv2d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv2d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+
+#conv_3d_accesses = [
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2)>, // in
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m1, n1, k1)>, // filter
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m, n, k)> // out
+]
+
+#conv_3d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m,n,k) += A(m,n,k) * B(m1,n1,k1)",
+ indexing_maps = #conv_3d_accesses,
+ library_call = "linalg_conv_3d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+ linalg.generic #conv_3d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?x?x?xf32>,
+ memref<?x?x?xf32>,
+ memref<?x?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv3d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[c2:.*]] = constant 2 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
+// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv3d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
+// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+
+#conv_4d_accesses = [
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2, l + l1 - s3 floordiv 2)>, // in
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m1, n1, k1, l1)>, // filter
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m, n, k, l)> // out
+]
+
+#conv_4d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m,n,k,l) += A(m,n,k,l) * B(m1,n1,k1,l1)",
+ indexing_maps = #conv_4d_accesses,
+ library_call = "linalg_conv_4d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : memref<?x?x?x?xf32>) -> () {
+ linalg.generic #conv_4d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?x?x?x?xf32>,
+ memref<?x?x?x?xf32>,
+ memref<?x?x?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv4d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[c2:.*]] = constant 2 : index
+// CHECKLOOP: %[[c3:.*]] = constant 3 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim6]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim7]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i6:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i7:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
+// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
+// CHECKLOOP: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv4d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
+// CHECKPARALLEL: %[[c3:.*]] = constant 3 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]], %[[i6:.*]], %[[i7:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim4]], %[[dim5]], %[[dim6]], %[[dim7]], %[[dim0]], %[[dim1]], %[[dim2]], %[[dim3]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
+// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
+// CHECKPARALLEL: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>