[mlir][Linalg] Conv {1,2,3}D ops defined with TC syntax
authorJakub Lichman <limo@google.com>
Fri, 31 Jul 2020 11:18:11 +0000 (13:18 +0200)
committerAlex Zinenko <zinenko@google.com>
Fri, 31 Jul 2020 11:20:17 +0000 (13:20 +0200)
Replaced definition of named ND ConvOps with tensor comprehension
syntax which reduces boilerplate code significantly. Furthermore,
new ops to support TF convolutions added (without strides and dilations).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D84628

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir

index 056f072..27d4330 100644 (file)
@@ -17,3 +17,55 @@ ods_def<BatchMatmulOp>:
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
 }
+
+ods_def<ConvWOp>:
+def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
+  O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
+}
+
+ods_def<ConvNWCOp>:
+def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
+  O(n, w, f) = std_addf(O(n, w, f),
+    std_mulf(I(n, w + kw, c), K(f, kw, c)));
+}
+
+ods_def<ConvNCWOp>:
+def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
+  O(n, f, w) = std_addf(O(n, f, w),
+    std_mulf(I(n, c, w + kw), K(f, c, kw)));
+}
+
+ods_def<ConvHWOp>:
+def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
+  O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
+}
+
+ods_def<ConvNHWCOp>:
+def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
+  O(n, h, w, f) = std_addf(O(n, h, w, f),
+    std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+}
+
+ods_def<ConvNCHWOp>:
+def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
+  O(n, f, h, w) = std_addf(O(n, f, h, w),
+    std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+}
+
+ods_def<ConvDHWOp>:
+def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
+  O(d, h, w) = std_addf(O(d, h, w),
+    std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+}
+
+ods_def<ConvNDHWCOp>:
+def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
+  O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
+    std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+}
+
+ods_def<ConvNCDHWOp>:
+def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
+  O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
+    std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+}
\ No newline at end of file
index 75e6599..21bff41 100644 (file)
@@ -85,14 +85,6 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
-/// Generates indexing maps for convolution with the following structure:
-/// input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-/// kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-/// output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-/// where r is the rank of the input, kernel and output
-llvm::Optional<SmallVector<AffineMap, 8>>
-createConvNDIndexingMaps(MLIRContext *context, unsigned rank);
-
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
 
 #define GET_OP_CLASSES
index 84ae8e4..1e3321a 100644 (file)
@@ -180,131 +180,6 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
-class ConvOpBase<string mnemonic, int N>
-  : LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
-  let description = [{
-    Base operation for any N-D Convolution implemented as a linalg.generic op.
-
-    Usage:
-
-    ```mlir
-    linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>,
-                                          memref<(?x)+f32>,
-                                          memref<(?x)+f32>
-    ```
-
-    where    %in:     input array
-             %filter: kernel or filter that will be applied on the input array
-             %out:    output array
-
-    and rank of the operands is *N*.
-
-    Every child convolution is expressed as:
-
-    ```mlir
-    #conv_trait = {
-      args_in = 2,
-      args_out = 1,
-      indexing_maps = #conv_accesses,
-      library_call  = "linalg_conv",
-      iterator_types = [("parallel", "parallel")+], // `2 * rank` iterators
-    }
-
-    linalg.generic #conv_trait %in, %filter, %out {
-      ^bb0(%a: f32, %b: f32, %c: f32) :
-        %d = mulf %a, %b : f32
-        %e = addf %c, %d : f32
-        linalg.yield %e : f32
-    } : memref<(?x)+f32>,
-        memref<(?x)+f32>,
-        memref<(?x)+f32>
-    ```
-
-    where #conv_accesses depend on the rank of the operands and thus
-    can be found in the documentation of each N-D case.
-    Please note that the input array is expected to be right-padded i.e.
-    the size of the input is greater than or equal to the size of the output
-    + size of the kernel - 1. If it is not padded the behavior of the op
-    is undefined.
-  }];
-
-  let arguments = (ins AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>);
-
-  let extraClassDeclaration = libraryCallName # [{
-    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
-      // There are always 2 loops for each dimension of the convolution. First
-      // iterates output and second kernel. Since ranks of all 3 operands must
-      // be the same it does not matter which operand is picked to get the rank.
-      // Loops iterating the output can be parallelized and thus are marked as
-      // "parallel" while loops iterating the kernel are accumulating the
-      // products and therefore are marked as "reduction".
-      unsigned rank = getInputShapedType(0).getRank();
-      SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName());
-      SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName());
-      parallel.insert(parallel.end(), reduction.begin(), reduction.end());
-      return parallel;
-    }
-
-    // Generates indexing maps with the following structure:
-    // input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-    // kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-    // output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-    // where r is the rank of the input, kernel and output
-    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
-      MLIRContext *context = getContext();
-      unsigned rank = getInputShapedType(0).getRank();
-      return createConvNDIndexingMaps(context, rank);
-    }
-  }];
-
-  let hasFolder = 1;
-  let verifier = [{ return ::verify(*this); }];
-}
-
-def Conv1DOp : ConvOpBase<"conv1D", 1> {
-  let description = [{
-    *1D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m, n) -> (m + n)>, // in
-      affine_map<(m, n) -> (n)>, // kernel
-      affine_map<(m, n) -> (m)> // out
-    ]
-    ```
-  }];
-}
-
-def Conv2DOp : ConvOpBase<"conv2D", 2> {
-  let description = [{
-    *2D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in
-      affine_map<(m1, m2, n1, n2) -> (n1, n2)>, // kernel
-      affine_map<(m1, m2, n1, n2) -> (m1, m2) // out
-    ]
-    ```
-  }];
-}
-
-def Conv3DOp : ConvOpBase<"conv3D", 3> {
-  let description = [{
-    *3D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>, // kernel
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)> // out
-    ]
-    ```
-  }];
-}
-
 /// A base class for pooling operation such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 ///   OptionalAttr<I64ArrayAttr>:$strides
index 921445b..55ffa3f 100644 (file)
@@ -236,9 +236,6 @@ void mlir::populateLinalgToStandardConversionPatterns(
       LinalgOpConversion<PoolingMinOp>,
       LinalgOpConversion<PoolingSumOp>,
       LinalgOpConversion<CopyOp>,
-      LinalgOpConversion<Conv1DOp>,
-      LinalgOpConversion<Conv2DOp>,
-      LinalgOpConversion<Conv3DOp>,
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
       LinalgOpConversion<IndexedGenericOp>>(ctx);
index e67adf8..03bd71f 100644 (file)
@@ -986,17 +986,6 @@ static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
   return success();
 }
 
-template <typename ConvNDOp>
-static LogicalResult verify(ConvNDOp op) {
-  auto outputType = op.getOutputShapedType(0).getElementType();
-  auto inputType = op.getInputShapedType(0).getElementType();
-  auto kernelType = op.getInputShapedType(1).getElementType();
-  if (outputType != inputType || inputType != kernelType)
-    return op.emitOpError("expected all element types of operands to match");
-
-  return success();
-}
-
 static LogicalResult verify(ConvOp op) {
   auto oType = op.output().getType().cast<MemRefType>();
   auto fType = op.filter().getType().cast<MemRefType>();
@@ -1107,27 +1096,6 @@ mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
   return res;
 }
 
-llvm::Optional<SmallVector<AffineMap, 8>>
-mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) {
-  unsigned numDims = rank * 2, idx = 0;
-
-  SmallVector<AffineExpr, 8> dims, in, kernel, out;
-  dims = makeAffineDimExprs(numDims, idx, context);
-  in.reserve(rank);
-  kernel.reserve(rank);
-  out.reserve(rank);
-
-  for (unsigned i = 0; i < rank; i++) {
-    in.push_back(dims[i] + dims[rank + i]);
-    kernel.push_back(dims[rank + i]);
-    out.push_back(dims[i]);
-  }
-
-  return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context),
-                                   AffineMap::get(numDims, 0, kernel, context),
-                                   AffineMap::get(numDims, 0, out, context)};
-}
-
 #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
   template SmallVector<AffineExpr, 4>                                          \
   mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
@@ -1209,18 +1177,6 @@ LogicalResult FillOp::fold(ArrayRef<Attribute>,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
-LogicalResult Conv1DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv2DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv3DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
@@ -1362,3 +1318,39 @@ LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
+                            SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
index db29835..281edd9 100644 (file)
@@ -295,61 +295,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
 }
 
-/// Following functions emit scalar part of the N-D convolution op.
-/// N-D convolution has 2N loops:
-///   1-N: Iterate over the output array *O* with iterators *m1, ..., mN*.
-///   N-2N:. Iterate over the kernel *K* with iterators *n1, ..., nN*.
-///
-/// The scalar part accumulates products of input array *I* values with kernel
-/// ones. The accumulation expression therefore looks like:
-///   O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN].
-/// Note that the input array has to be padded in order to prevent
-/// out of bounds accesses.
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 2);
-  Value m1(allIvs[0]);
-  Value n1(allIvs[1]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 1D conv case.
-  Value i1 = m1 + n1;
-  O(m1) = O(m1) + I(i1) * K(n1);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 4);
-  Value m1(allIvs[0]), m2(allIvs[1]);
-  Value n1(allIvs[2]), n2(allIvs[3]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 2D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 6);
-  Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]);
-  Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 3D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  Value i3 = m3 + n3;
-  O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3);
-}
-
 template <typename IndexedValueType>
 Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                      MutableArrayRef<Value> imIdx) {
@@ -738,6 +683,24 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
     return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
   if (isa<BatchMatmulOp>(op))
     return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
+  if (isa<ConvWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
+  if (isa<ConvNWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
+  if (isa<ConvNCWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
+  if (isa<ConvHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
+  if (isa<ConvNHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
+  if (isa<ConvNCHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
+  if (isa<ConvDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
+  if (isa<ConvNDHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
+  if (isa<ConvNCDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
 }
 
index a5a6e9b..ca59ecd 100644 (file)
@@ -507,11 +507,3 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
   linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
   return
 }
-
-// -----
-
-func @conv_type_mismatch(%in: memref<?xi32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
-  // expected-error @+1 {{expected all element types of operands to match}}
-  linalg.conv1D(%in, %filter, %out) : memref<?xi32>, memref<?xf32>, memref<?xf32>
-  return
-}
index ee63d59..6af53a2 100644 (file)
@@ -1288,7 +1288,7 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out :  m
 //       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
 
 func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
-  linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+  linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
   return
 }
 
@@ -1303,10 +1303,10 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 //       CHECKLOOP: scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 //       CHECKLOOP:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
 //       CHECKLOOP:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-//       CHECKLOOP:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 //       CHECKLOOP:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-//       CHECKLOOP:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKLOOP:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 //       CHECKLOOP:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKLOOP:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
@@ -1318,19 +1318,18 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 //       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
 //       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-//       CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
 
 func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
-  linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
   return
 }
 // CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1349,10 +1348,12 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 //       CHECKLOOP:       scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 //       CHECKLOOP:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
 //       CHECKLOOP:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-//       CHECKLOOP:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 //       CHECKLOOP:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-//       CHECKLOOP:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+//       CHECKLOOP:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 //       CHECKLOOP:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+
+//       CHECKLOOP:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
@@ -1366,21 +1367,19 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//       CHECKPARALLEL:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
-//       CHECKPARALLEL:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-//       CHECKPARALLEL:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
 
 func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
-  linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
+  linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
   return
 }
 
@@ -1406,10 +1405,12 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKLOOP:             %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
 //       CHECKLOOP:             %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
 //       CHECKLOOP:             %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-//       CHECKLOOP:             %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 //       CHECKLOOP:             %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-//       CHECKLOOP:             %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+//       CHECKLOOP:             %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 //       CHECKLOOP:             %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+//       CHECKLOOP:             %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:             %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:             store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
 
@@ -1426,16 +1427,13 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//       CHECKPARALLEL:       scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
-//       CHECKPARALLEL:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
-//       CHECKPARALLEL:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
-//       CHECKPARALLEL:         %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-//       CHECKPARALLEL:         %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+//       CHECKPARALLEL:   %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>