Templatize linalg::LowerToLoops - NFC
authorNicolas Vasilache <ntv@google.com>
Fri, 15 Nov 2019 15:12:17 +0000 (07:12 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 15 Nov 2019 15:12:51 +0000 (07:12 -0800)
This modification will allow to easily plug lowering of linalg ops to different types of loops (affine, loop.for and other future constructs).
This is purely NFC for now.

PiperOrigin-RevId: 280652186

mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp

index 1b30093..058dc07 100644 (file)
@@ -40,7 +40,7 @@ using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 using namespace mlir::linalg::intrinsics;
 
-using IndexedLinalgValue = TemplatedIndexedValue<std_load, std_store>;
+using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
 using edsc::op::operator+;
 using edsc::op::operator==;
 
@@ -76,7 +76,11 @@ static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
                                               ArrayRef<Value *> allViewSizes,
-                                              OperationFolder *folder) {
+                                              OperationFolder *folder);
+SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
+                                       AffineMap map,
+                                       ArrayRef<Value *> allViewSizes,
+                                       OperationFolder *folder) {
   // Apply `map` to get view sizes in loop order.
   auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
   // Create a new range with the applied tile sizes.
@@ -89,9 +93,11 @@ static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
   return res;
 }
 
-template <typename LinalgOpType> class LinalgScopedEmitter {};
+template <typename IndexedValueType, typename LinalgOpType>
+class LinalgScopedEmitter {};
 
-template <> class LinalgScopedEmitter<CopyOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, CopyOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
                                        OperationFolder *folder) {
@@ -103,7 +109,7 @@ public:
         permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
     SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
     SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
-    IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+    IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
     // Emit the proper scalar assignment, whether we are dealing with a 0-D or
     // an n-D loop nest; with or without permutations.
     // clang-format off
@@ -113,7 +119,8 @@ public:
   }
 };
 
-template <> class LinalgScopedEmitter<FillOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, FillOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
                                        OperationFolder *folder) {
@@ -121,7 +128,7 @@ public:
     assert(nPar == allIvs.size());
     auto ivs =
         SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
-    IndexedLinalgValue O(fillOp.getOutput(0));
+    IndexedValueType O(fillOp.getOutput(0));
     // Emit the proper scalar assignment, whether we are dealing with a 0-D or
     // an n-D loop nest; with or without permutations.
     nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
@@ -129,48 +136,52 @@ public:
   }
 };
 
-template <> class LinalgScopedEmitter<DotOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, DotOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
                                        OperationFolder *folder) {
     assert(allIvs.size() == 1);
     IndexHandle r_i(allIvs[0]);
-    IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+    IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
         C(dotOp.getOutput(0));
     // Emit scalar form.
     C() = C() + A(r_i) * B(r_i);
   }
 };
 
-template <> class LinalgScopedEmitter<MatvecOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        MatvecOp matvecOp,
                                        OperationFolder *folder) {
     assert(allIvs.size() == 2);
     IndexHandle i(allIvs[0]), r_j(allIvs[1]);
-    IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+    IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
         C(matvecOp.getOutput(0));
     // Emit scalar form.
     C(i) = C(i) + A(i, r_j) * B(r_j);
   }
 };
 
-template <> class LinalgScopedEmitter<MatmulOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        MatmulOp matmulOp,
                                        OperationFolder *folder) {
     assert(allIvs.size() == 3);
     IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
-    IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+    IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
         C(matmulOp.getOutput(0));
     // Emit scalar form.
     C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
   }
 };
 
-template <> class LinalgScopedEmitter<ConvOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
                                        OperationFolder *folder) {
@@ -183,8 +194,7 @@ public:
         foldedAffineApplies(b, loc, maps[1], allIvs, folder));
     SmallVector<ValueHandle, 8> oIdx(
         foldedAffineApplies(b, loc, maps[2], allIvs, folder));
-    IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
-        O(convOp.output());
+    IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
     // Emit scalar form.
     O(oIdx) += F(fIdx) * I(imIdx);
   }
@@ -205,22 +215,23 @@ public:
 //    loop.for %i = %c0 to %0 step %c1 {
 //      loop.for %j = %c0 to %1 step %c1 {
 //        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = linalg.load %arg0[%i, %j] :
+//          %11 = load %arg0[%i, %j] :
 //            memref<?x?xf32, stride_specification>
-//          %12 = linalg.load %arg1[%i, %j, %k] :
+//          %12 = load %arg1[%i, %j, %k] :
 //            memref<?x?x?xf32, stride_specification>
-//          %13 = linalg.load %arg2[%i, %k, %j] :
+//          %13 = load %arg2[%i, %k, %j] :
 //            memref<?x?x?xf32, stride_specification>
 //          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-//          linalg.store %14#0, %arg1[%i, %j, %k] :
+//          store %14#0, %arg1[%i, %j, %k] :
 //            memref<?x?x?Xf32, stride_specification>
-//          linalg.store %14#1, %arg2[%i, %k, %j] :
+//          store %14#1, %arg2[%i, %k, %j] :
 //            memref<?x?x?Xf32, stride_specification>
 //       }
 //      }
 //    }
 // ```
-template <> class LinalgScopedEmitter<GenericOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        GenericOp genericOp,
@@ -288,7 +299,8 @@ public:
   }
 };
 
-template <> class LinalgScopedEmitter<IndexedGenericOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        IndexedGenericOp genericOp,
@@ -298,7 +310,7 @@ public:
   }
 };
 
-template <typename ConcreteOp>
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
 class LinalgRewritePattern : public RewritePattern {
 public:
   explicit LinalgRewritePattern(MLIRContext *context)
@@ -316,7 +328,8 @@ public:
     auto invertedMap =
         inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
     if (!invertedMap) {
-      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
+      LinalgScopedEmitter<IndexedValueType,
+                          ConcreteOp>::emitScalarImplementation({}, linalgOp,
                                                                 &folder);
       rewriter.eraseOp(op);
       return matchSuccess();
@@ -327,28 +340,18 @@ public:
     auto nWin = linalgOp.getNumWindowLoops();
     SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
     SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
-    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
-    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
-                    .take_front(nPar + nRed)
-                    .take_back(nRed);
-    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
-
     auto loopRanges =
         emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
                        getViewSizes(linalgOp), &folder);
-    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
-
-    // clang-format off
-    ArrayRef<Value *> ranges(loopRanges);
-    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
-      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
-        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
-          [&linalgOp, &allIvs, this] {
-            auto allIvValues = extractValues(allIvs);
-            LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
-                allIvValues, linalgOp, &folder);
-        });
-      });
+    assert(loopRanges.size() == allIvs.size());
+
+    // clang-format off;
+    LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+      auto allIvValues = extractValues(allIvs);
+      LinalgScopedEmitter<IndexedValueType,
+                          ConcreteOp>::emitScalarImplementation(allIvValues,
+                                                                linalgOp,
+                                                                &folder);
     });
     // clang-format on
     rewriter.eraseOp(op);
@@ -359,56 +362,68 @@ public:
 };
 
 // Helper classes for type list expansion.
-template <typename... LinalgOps> class ConversionList;
+template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
+class ConversionList;
 
-template <> class ConversionList<> {
+template <typename LoopType, typename IndexedValueType>
+class ConversionList<LoopType, IndexedValueType> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
 };
 
-template <typename ConcreteOp, typename... LinalgOps>
-class ConversionList<ConcreteOp, LinalgOps...> {
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
+          typename... LinalgOps>
+class ConversionList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-    patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
-    ConversionList<LinalgOps...>::build(patterns, ctx);
+    patterns
+        .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
+            ctx);
+    ConversionList<LoopType, IndexedValueType, LinalgOps...>::build(patterns,
+                                                                    ctx);
   }
 };
 
 /// Populate the given list with patterns that convert from Linalg to LLVM.
-static void
-populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
-                                    MLIRContext *ctx) {
-  ConversionList<
+template <typename LoopType, typename IndexedValueType>
+void ForOpRewritePatterns(OwningRewritePatternList &patterns,
+                          MLIRContext *ctx) {
+  ConversionList<LoopType, IndexedValueType,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
-      >::build(patterns, ctx);
+                 >::build(patterns, ctx);
 }
 
 namespace {
-struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
+template <typename LoopType, typename IndexedValueType>
+struct LowerLinalgToLoopsPass
+    : public FunctionPass<LowerLinalgToLoopsPass<LoopType, IndexedValueType>> {
   void runOnFunction() override;
 };
 } // namespace
 
-void LowerLinalgToLoopsPass::runOnFunction() {
+template <typename LoopType, typename IndexedValueType>
+void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
   OwningRewritePatternList patterns;
-  populateLinalgToLoopRewritePatterns(patterns, &getContext());
+  ForOpRewritePatterns<LoopType, IndexedValueType>(patterns,
+                                                   &this->getContext());
 
-  ConversionTarget target(getContext());
+  ConversionTarget target(this->getContext());
   target.addLegalDialect<AffineOpsDialect>();
   target.addLegalDialect<loop::LoopOpsDialect>();
   target.addLegalDialect<StandardOpsDialect>();
-  if (failed(applyPartialConversion(getFunction(), target, patterns))) {
-    signalPassFailure();
+  if (failed(applyPartialConversion(this->getFunction(), target, patterns))) {
+    this->signalPassFailure();
   }
 }
 
 std::unique_ptr<OpPassBase<FuncOp>>
 mlir::linalg::createLowerLinalgToLoopsPass() {
-  return std::make_unique<LowerLinalgToLoopsPass>();
+  return std::make_unique<
+      LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
 }
 
-static PassRegistration<LowerLinalgToLoopsPass>
-    pass("linalg-lower-to-loops",
-         "Lower the operations from the linalg dialect into loops");
+static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
+    structuredLoopsPass(
+        "linalg-lower-to-loops",
+        "Lower the operations from the linalg dialect into loops");