[mlir][linalg] Remove IndexedGenericOp support from LinalgToLoops...
authorTobias Gysi <gysit@google.com>
Tue, 11 May 2021 06:52:44 +0000 (06:52 +0000)
committerTobias Gysi <gysit@google.com>
Tue, 11 May 2021 06:53:47 +0000 (06:53 +0000)
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).

Differential Revision: https://reviews.llvm.org/D102187

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/test/Dialect/Linalg/loops.mlir

index 5d5dc2a..b1bf213 100644 (file)
@@ -378,84 +378,6 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
       getPoolingInput<IndexedValueType>(op, indices.inputs);
 }
 
-/// Emits the MLIR for the scalar part of the indexed generic op by:
-///   1. Emitting load ops for each input and output view in order. This is
-///      achieved by applying the appropriate input or output map to the
-///      enclosing induction variables.
-///   2. Emitting a call to `op.fun()` that takes as arguments the induction
-///      variables and the scalars from point 1. above.
-///   3. Emitting store ops to store the results of 2. to the output views.
-///
-/// An example output may resemble:
-///
-/// ```
-///    scf.for %i = %c0 to %0 step %c1 {
-///      scf.for %j = %c0 to %1 step %c1 {
-///        scf.for %k = %c0 to %4 step %c1 {
-///          %11 = load %arg0[%i, %j] :
-///            memref<?x?xf32, stride_specification>
-///          %12 = load %arg1[%i, %j, %k] :
-///            memref<?x?x?xf32, stride_specification>
-///          %13 = load %arg2[%i, %k, %j] :
-///            memref<?x?x?xf32, stride_specification>
-///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-///            (index, index, index, f32, f32, f32) -> (f32, f32)
-///          store %14#0, %arg1[%i, %j, %k] :
-///            memref<?x?x?Xf32, stride_specification>
-///          store %14#1, %arg2[%i, %k, %j] :
-///            memref<?x?x?Xf32, stride_specification>
-///       }
-///      }
-///    }
-/// ```
-template <typename IndexedValueType>
-static void emitScalarImplementation(ArrayRef<Value> allIvs,
-                                     IndexedGenericOp indexedGenericOp) {
-  assert(indexedGenericOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto &b = ScopedContext::getBuilderRef();
-  auto loc = ScopedContext::getLocation();
-  unsigned nInputs = indexedGenericOp.getNumInputs();
-  unsigned nOutputs = indexedGenericOp.getNumOutputs();
-  unsigned nLoops = allIvs.size();
-  SmallVector<Value, 4> indexedValues;
-  indexedValues.reserve(nLoops + nInputs + nOutputs);
-  for (unsigned i = 0; i < nLoops; ++i)
-    indexedValues.push_back(allIvs[i]);
-
-  // TODO: Avoid the loads if the corresponding argument of the
-  // region has no uses.
-  // 1.a. Emit load from input views.
-  for (unsigned i = 0; i < nInputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
-    // Pass input i through IndexedValueType emits the proper load operation.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
-  }
-  // 1.b. Emit load from output views.
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    auto indexing = makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
-    // Pass output i through IndexedValueType emits the proper load operation.
-    indexedValues.push_back(
-        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
-  }
-
-  // TODO: When a region inliner exists, use it.
-  // 2. Inline region, currently only works for a single basic block.
-  // 3. Emit store.
-  SmallVector<SmallVector<Value, 8>, 8> indexing;
-  SmallVector<Value, 8> outputBuffers;
-  for (unsigned i = 0; i < nOutputs; ++i) {
-    indexing.push_back(makeCanonicalAffineApplies(
-        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
-  }
-  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
-                                             indexing, outputBuffers);
-}
-
 template <typename LoopTy>
 static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
                                                  OpBuilder &builder) {
@@ -477,10 +399,10 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
         llvm::TypeSwitch<Operation *>(linalgOp)
-            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
-                  IndexedGenericOp, LinalgOp>([&](auto op) {
-              emitScalarImplementation<IndexedValueTy>(allIvs, op);
-            })
+            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
+                [&](auto op) {
+                  emitScalarImplementation<IndexedValueTy>(allIvs, op);
+                })
             .Default([&](Operation *op) { assert(false && "unexpected op"); });
         return scf::ValueVector{};
       });
@@ -697,6 +619,10 @@ template <typename LoopTy>
 Optional<LinalgLoops>
 mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
+  // Convert indexed_generic ops to generic ops before lowering them to loops.
+  if (isa<IndexedGenericOp>(linalgOp))
+    return llvm::None;
+
   Optional<LinalgLoops> loopOps =
       linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
   if (loopOps.hasValue())
index b0ffc7f..b958e0a 100644 (file)
@@ -935,58 +935,6 @@ func @generic_index_region(
 //       CHECKPARALLEL:   store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
 //       CHECKPARALLEL:   store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
 
-func @indexed_generic_region(
-        %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-        %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-        %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait4
-      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
-     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-    ^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
-      %result_1 = mulf %a, %b : f32
-
-      %ij = addi %i, %j : index
-      %ijk = addi %ij, %k : index
-      %ijk_int = index_cast %ijk : index to i32
-      %ijk_float = sitofp %ijk_int : i32 to f32
-
-      %result_2 = addf %c, %ijk_float : f32
-      linalg.yield %result_1, %result_2 : f32, f32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_region
-//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-//       CHECKLOOP:   scf.for %[[j:.*]] = {{.*}}
-//       CHECKLOOP:     scf.for %[[k:.*]] = {{.*}}
-//       CHECKLOOP:       %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-//       CHECKLOOP:       %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-//       CHECKLOOP:       %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-//       CHECKLOOP:       %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-//       CHECKLOOP:       %[[ij:.*]] = addi %[[i]], %[[j]] : index
-//       CHECKLOOP:       %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-//       CHECKLOOP:       %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-//       CHECKLOOP:       %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-//       CHECKLOOP:       %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-//       CHECKLOOP:       store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-//       CHECKLOOP:       store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_region
-//       CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-//       CHECKPARALLEL:   %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
-//       CHECKPARALLEL:   %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
-//       CHECKPARALLEL:   %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
-//       CHECKPARALLEL:   %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
-//       CHECKPARALLEL:   %[[ij:.*]] = addi %[[i]], %[[j]] : index
-//       CHECKPARALLEL:   %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
-//       CHECKPARALLEL:   %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
-//       CHECKPARALLEL:   %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
-//       CHECKPARALLEL:   %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
-//       CHECKPARALLEL:   store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
-//       CHECKPARALLEL:   store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
-
 // -----
 
 #broadcast_access = [
@@ -1065,41 +1013,6 @@ func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 //       CHECKPARALLEL:   %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 //       CHECKPARALLEL:   store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
 
-func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
-{
-  linalg.indexed_generic #trait_broadcast
-      ins(%arg0 : memref<i32>)
-     outs(%arg1 : memref<3x4xi32>) {
-    ^bb(%i: index, %j: index, %a: i32, %b: i32) :
-      %ij = addi %i, %j : index
-      %ij_int = index_cast %ij : index to i32
-      %result = addi %a, %ij_int : i32
-      linalg.yield %result : i32
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @indexed_generic_op_zero_rank
-//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-//       CHECKLOOP:   scf.for %[[j:.*]] = {{.*}}
-//       CHECKLOOP:     %[[a:.*]] = memref.load %[[ARG0]][
-//       CHECKLOOP:     %[[ij:.*]] = addi %[[i]], %[[j]] : index
-//       CHECKLOOP:     %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-//       CHECKLOOP:     %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-//       CHECKLOOP:     store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_zero_rank
-//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
-//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
-//       CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
-//       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][
-//       CHECKPARALLEL:   %[[ij:.*]] = addi %[[i]], %[[j]] : index
-//       CHECKPARALLEL:   %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
-//       CHECKPARALLEL:   %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
-//       CHECKPARALLEL:   store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
-
 #reduce_1D_access = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> ()>
@@ -1198,46 +1111,6 @@ func @generic_index_op_1D_reduce(%arg0: memref<?xf32>,
 //       CHECKPARALLEL:   %[[e:.*]] = addf %[[a]], %[[d]]
 //       CHECKPARALLEL:   store %[[e]], %[[ARG2]][]
 
-func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
-                                   %arg1: memref<f32>,
-                                   %arg2: memref<f32>)
-{
-  linalg.indexed_generic #trait_reduce_init_1D
-      ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
-     outs(%arg2 : memref<f32>) {
-    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
-      %0 = constant 0 : index
-      %1 = cmpi eq, %0, %i : index
-      %2 = select %1, %b, %c : f32
-      %3 = addf %a, %2 : f32
-      linalg.yield %3 : f32
-  }
-  return
-}
-// CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce
-//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-//  CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
-//       CHECKLOOP:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-//       CHECKLOOP:   %[[b:.*]] = memref.load %[[ARG1]][]
-//       CHECKLOOP:   %[[c:.*]] = memref.load %[[ARG2]][]
-//       CHECKLOOP:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-//       CHECKLOOP:   %[[e:.*]] = addf %[[a]], %[[d]]
-//       CHECKLOOP:   store %[[e]], %[[ARG2]][]
-
-// CHECKPARALLEL-LABEL: @indexed_generic_op_1D_reduce
-//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
-//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
-//  CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
-//       CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
-//       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
-//       CHECKPARALLEL:   %[[b:.*]] = memref.load %[[ARG1]][]
-//       CHECKPARALLEL:   %[[c:.*]] = memref.load %[[ARG2]][]
-//       CHECKPARALLEL:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
-//       CHECKPARALLEL:   %[[e:.*]] = addf %[[a]], %[[d]]
-//       CHECKPARALLEL:   store %[[e]], %[[ARG2]][]
-
 #trait_const_fill = {
   args_in = 0,
   args_out = 1,