Standardize Linalg transformations to take an OpBuilder and an OperationFolder - NFC
author: Nicolas Vasilache <ntv@google.com>
Mon, 28 Oct 2019 21:55:43 +0000 (14:55 -0700)
committer: A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 28 Oct 2019 21:56:20 +0000 (14:56 -0700)
This will be used to specify declarative Linalg transformations in a followup CL. In particular, the PatternRewrite mechanism does not allow folding and has its own way of tracking erasure.

PiperOrigin-RevId: 277149158

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/EDSC/Builders.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp

index a8daaac..2b2fdfb 100644 (file)
@@ -83,9 +83,13 @@ struct FusionInfo {
 
 // Fuses producer into consumer if the producer is structurally feasible and the
 // fusion would not violate dependencies.
-Optional<FusionInfo> fuseProducerOf(LinalgOp consumer, unsigned consumerIdx,
+/// When non-null, the optional pointer `folder` is used to call into the
+/// `createAndFold` builder method. If `folder` is null, the regular `create`
+/// method is called.
+Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
+                                    unsigned consumerIdx,
                                     LinalgDependenceGraph &graph,
-                                    OperationFolder &state);
+                                    OperationFolder *folder = nullptr);
 
 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
@@ -102,11 +106,13 @@ SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
 }
 
 /// Returns the values obtained by applying `map` to the list of values.
-/// Performs simplifications and foldings where possible.
+/// When non-null, the optional pointer `folder` is used to call into the
+/// `createAndFold` builder method. If `folder` is null, the regular `create`
+/// method is called.
 SmallVector<Value *, 4> applyMapToValues(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<Value *> values,
-                                         OperationFolder &state);
+                                         OperationFolder *folder = nullptr);
 
 struct TiledLinalgOp {
   LinalgOp op;
@@ -116,14 +122,28 @@ struct TiledLinalgOp {
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// Returns a struct containing the tiled loops and the cloned op if successful,
 /// llvm::None otherwise.
-llvm::Optional<TiledLinalgOp>
-tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes, OperationFolder &folder);
+/// When non-null, the optional pointer `folder` is used to call into the
+/// `createAndFold` builder method. If `folder` is null, the regular `create`
+/// method is called.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
+                                           ArrayRef<Value *> tileSizes,
+                                           OperationFolder *folder = nullptr);
 
 /// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
 /// Returns a struct containing the tiled loops and the cloned op if successful,
 /// llvm::None otherwise.
-llvm::Optional<TiledLinalgOp>
-tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes, OperationFolder &folder);
+/// When non-null, the optional pointer `folder` is used to call into the
+/// `createAndFold` builder method. If `folder` is null, the regular `create`
+/// method is called.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
+                                           ArrayRef<int64_t> tileSizes,
+                                           OperationFolder *folder = nullptr);
+
+template <typename... Args>
+llvm::Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
+                                                  Args... args) {
+  return tileLinalgOp(b, cast<LinalgOp>(op), args...);
+}
 
 struct PromotionInfo {
   Value *buffer;
@@ -142,7 +162,7 @@ struct PromotionInfo {
 /// full and partial views indexing into the buffer.
 llvm::SmallVector<PromotionInfo, 8> promoteSubViews(OpBuilder &b, Location loc,
                                                     ArrayRef<Value *> subViews,
-                                                    OperationFolder &folder);
+                                                    OperationFolder *folder);
 
 /// Returns all the operands of `linalgOp` that are not views.
 /// Asserts that these operands are value types to allow transformations like
index 5a80571..1927ce6 100644 (file)
@@ -348,8 +348,11 @@ public:
 
   /// Generic mlir::Op create. This is the key to being extensible to the whole
   /// of MLIR without duplicating the type system or the op definitions.
+  /// When non-null, the optional pointer `folder` is used to call into the
+  /// `createAndFold` builder method. If `folder` is null, the regular `create`
+  /// method is called.
   template <typename Op, typename... Args>
-  static ValueHandle create(OperationFolder &folder, Args... args);
+  static ValueHandle create(OperationFolder *folder, Args... args);
 
   /// Special case to build composed AffineApply operations.
   // TODO: createOrFold when available and move inside of the `create` method.
@@ -497,9 +500,12 @@ ValueHandle ValueHandle::create(Args... args) {
 }
 
 template <typename Op, typename... Args>
-ValueHandle ValueHandle::create(OperationFolder &folder, Args... args) {
-  return ValueHandle(folder.create<Op>(ScopedContext::getBuilder(),
-                                       ScopedContext::getLocation(), args...));
+ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
+  return folder ? ValueHandle(folder->create<Op>(ScopedContext::getBuilder(),
+                                                 ScopedContext::getLocation(),
+                                                 args...))
+                : ValueHandle(ScopedContext::getBuilder().create<Op>(
+                      ScopedContext::getLocation(), args...));
 }
 
 namespace op {
index c54dffe..ebdb32b 100644 (file)
@@ -75,10 +75,7 @@ static llvm::cl::list<unsigned> clTileSizes(
 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
 // to the `loopRanges` in order to obtain view ranges.
 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
-                                    ArrayRef<SubViewOp::Range> loopRanges,
-                                    OperationFolder &state) {
-  ScopedContext scope(b, loc);
-
+                                    ArrayRef<SubViewOp::Range> loopRanges) {
   auto maps = loopToOperandRangesMaps(op);
   SmallVector<Value *, 8> clonedViews;
   clonedViews.reserve(op.getNumInputsAndOutputs());
@@ -152,7 +149,7 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
 
 static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer,
                      unsigned consumerIdx, unsigned producerIdx,
-                     OperationFolder &state) {
+                     OperationFolder *folder) {
   auto subView = dyn_cast_or_null<SubViewOp>(
       consumer.getInput(consumerIdx)->getDefiningOp());
   auto slice = dyn_cast_or_null<SliceOp>(
@@ -192,15 +189,14 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer,
                  << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto viewDim = getViewDefiningLoopRange(producer, i);
-      loopRanges[i] =
-          SubViewOp::Range{state.create<ConstantIndexOp>(b, loc, 0),
-                           dim(viewDim.view, viewDim.dimension),
-                           state.create<ConstantIndexOp>(b, loc, 1)};
+      loopRanges[i] = SubViewOp::Range{constant_index(folder, 0),
+                                       dim(viewDim.view, viewDim.dimension),
+                                       constant_index(folder, 1)};
       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
 
-  return cloneWithLoopRanges(b, loc, producer, loopRanges, state);
+  return cloneWithLoopRanges(b, loc, producer, loopRanges);
 }
 
 // Encode structural fusion safety preconditions.
@@ -231,10 +227,11 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
 }
 
 // Only consider RAW atm.
-Optional<FusionInfo> mlir::linalg::fuseProducerOf(LinalgOp consumer,
+Optional<FusionInfo> mlir::linalg::fuseProducerOf(OpBuilder &b,
+                                                  LinalgOp consumer,
                                                   unsigned consumerIdx,
                                                   LinalgDependenceGraph &graph,
-                                                  OperationFolder &state) {
+                                                  OperationFolder *folder) {
   LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                     << *consumer.getOperation());
   for (auto dependence : graph.getDependencesInto(
@@ -270,11 +267,12 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(LinalgOp consumer,
       continue;
 
     // Fuse `producer` just before `consumer`.
-    OpBuilder builder(consumer.getOperation());
-    ScopedContext scope(builder, consumer.getLoc());
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(consumer.getOperation());
+    ScopedContext scope(b, consumer.getLoc());
     LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
-    auto fusedProducer =
-        fuse(producedView, producer, consumer, consumerIdx, producerIdx, state);
+    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
+                              producerIdx, folder);
 
     return FusionInfo{producer, fusedProducer};
   }
@@ -284,7 +282,8 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(LinalgOp consumer,
 static void fuseLinalgOpsGreedily(FuncOp f) {
   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
 
-  OperationFolder state(f.getContext());
+  OpBuilder b(f);
+  OperationFolder folder(f.getContext());
   DenseSet<Operation *> eraseSet;
 
   // Save original Linalg ops, we only want to make a pass over those.
@@ -296,7 +295,7 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
   for (auto *op : llvm::reverse(linalgOps)) {
     for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
          consumerIdx < e; ++consumerIdx) {
-      if (auto fusionInfo = fuseProducerOf(op, consumerIdx, G, state))
+      if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
         eraseSet.insert(fusionInfo->originalProducer.getOperation());
     }
   }
index e6070a6..6f6b2fc 100644 (file)
@@ -46,7 +46,7 @@ using edsc::op::operator==;
 
 static SmallVector<ValueHandle, 8>
 foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value *> vals, OperationFolder &folder) {
+                    ArrayRef<Value *> vals, OperationFolder *folder) {
   assert(map.getNumSymbols() == 0);
   assert(map.getNumInputs() == vals.size());
   SmallVector<ValueHandle, 8> res;
@@ -63,10 +63,10 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
 
 static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
                                           Optional<AffineMap> permutation,
-                                          OperationFolder &state) {
+                                          OperationFolder *folder) {
   return permutation ? applyMapToValues(ScopedContext::getBuilder(),
                                         ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs, state)
+                                        permutation.getValue(), ivs, folder)
                      : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
 }
 
@@ -76,7 +76,7 @@ static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
                                               ArrayRef<Value *> allViewSizes,
-                                              OperationFolder &folder) {
+                                              OperationFolder *folder) {
   // Apply `map` to get view sizes in loop order.
   auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
   // Create a new range with the applied tile sizes.
@@ -94,7 +94,7 @@ template <typename LinalgOpType> class LinalgScopedEmitter {};
 template <> class LinalgScopedEmitter<CopyOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     auto nPar = copyOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto inputIvs =
@@ -116,7 +116,7 @@ public:
 template <> class LinalgScopedEmitter<FillOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     auto nPar = fillOp.getNumParallelLoops();
     assert(nPar == allIvs.size());
     auto ivs =
@@ -132,7 +132,7 @@ public:
 template <> class LinalgScopedEmitter<DotOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     assert(allIvs.size() == 1);
     IndexHandle r_i(allIvs[0]);
     IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
@@ -146,7 +146,7 @@ template <> class LinalgScopedEmitter<MatvecOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        MatvecOp matvecOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     assert(allIvs.size() == 2);
     IndexHandle i(allIvs[0]), r_j(allIvs[1]);
     IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
@@ -160,7 +160,7 @@ template <> class LinalgScopedEmitter<MatmulOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        MatmulOp matmulOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     assert(allIvs.size() == 3);
     IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
     IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
@@ -173,7 +173,7 @@ public:
 template <> class LinalgScopedEmitter<ConvOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     auto maps = loopToOperandRangesMaps(convOp);
@@ -224,7 +224,7 @@ template <> class LinalgScopedEmitter<GenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                        GenericOp genericOp,
-                                       OperationFolder &folder) {
+                                       OperationFolder *folder) {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     using edsc::intrinsics::detail::ValueHandleArray;
@@ -307,7 +307,7 @@ public:
         inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
     if (!invertedMap) {
       LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
-                                                                folder);
+                                                                &folder);
       rewriter.eraseOp(op);
       return matchSuccess();
     }
@@ -325,7 +325,7 @@ public:
 
     auto loopRanges =
         emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
-                       getViewSizes(linalgOp), folder);
+                       getViewSizes(linalgOp), &folder);
     assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
 
     // clang-format off
@@ -336,7 +336,7 @@ public:
           [&linalgOp, &allIvs, this] {
             auto allIvValues = extractValues(allIvs);
             LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
-                allIvValues, linalgOp, folder);
+                allIvValues, linalgOp, &folder);
         });
       });
     });
index 5f73661..c9b7435 100644 (file)
@@ -81,7 +81,7 @@ static Value *allocBuffer(Type elementType, Value *size) {
 // by a partial `copy` op.
 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
                                            SubViewOp subView,
-                                           OperationFolder &folder) {
+                                           OperationFolder *folder) {
   auto zero = constant_index(folder, 0);
   auto one = constant_index(folder, 1);
 
@@ -113,7 +113,7 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
 SmallVector<PromotionInfo, 8>
 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
                               ArrayRef<Value *> subViews,
-                              OperationFolder &folder) {
+                              OperationFolder *folder) {
   if (subViews.empty())
     return {};
 
@@ -157,7 +157,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
 }
 
 static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
-                                   OperationFolder &folder) {
+                                   OperationFolder *folder) {
   // 1. Promote the specified views and use them in the new op.
   OpBuilder b(op);
   ScopedContext scope(b, op.getLoc());
@@ -211,7 +211,7 @@ static void promoteSubViews(FuncOp f) {
       if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
         subViews.insert(sv);
     if (!subViews.empty()) {
-      promoteSubViewOperands(op, subViews, folder);
+      promoteSubViewOperands(op, subViews, &folder);
       toErase.push_back(op);
     }
   });
index a499f34..c1d9755 100644 (file)
@@ -68,7 +68,7 @@ static bool isZero(Value *v) {
 static SmallVector<SubViewOp::Range, 4>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                     ArrayRef<Value *> allViewSizes,
-                    ArrayRef<Value *> allTileSizes, OperationFolder &folder) {
+                    ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
   assert(allTileSizes.size() == map.getNumResults());
   // Apply `map` to get view sizes in loop order.
   auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
@@ -141,7 +141,7 @@ static bool isTiled(AffineMap map, ArrayRef<Value *> tileSizes) {
 static SmallVector<Value *, 4>
 makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
                ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes,
-               ArrayRef<Value *> viewSizes, OperationFolder &folder) {
+               ArrayRef<Value *> viewSizes, OperationFolder *folder) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value *v) { return !isZero(v); })) &&
@@ -211,17 +211,22 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
   }
 
   // Traverse the mins/maxes and erase those that don't have uses left.
-  mins.append(maxes.begin(), maxes.end());
-  for (auto *v : mins)
-    if (v->use_empty())
-      v->getDefiningOp()->erase();
+  // This is a special type of folding that we only apply when `folder` is
+  // defined.
+  if (folder) {
+    mins.append(maxes.begin(), maxes.end());
+    for (auto *v : mins)
+      if (v->use_empty())
+        v->getDefiningOp()->erase();
+  }
 
   return res;
 }
 
 llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
-                           OperationFolder &folder) {
+mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
+                           ArrayRef<Value *> tileSizes,
+                           OperationFolder *folder) {
   // 1. Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
   // adjusting affine maps to account for missing dimensions.
@@ -229,9 +234,9 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
                  op.getNumWindowLoops() ==
              tileSizes.size() &&
          "expected matching number of tile sizes and loops");
-
-  OpBuilder builder(op.getOperation());
-  ScopedContext scope(builder, op.getLoc());
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  ScopedContext scope(b, op.getLoc());
   // 2. Build the tiled loop ranges.
   auto viewSizes = getViewSizes(op);
   // The flattened loopToOperandRangesMaps is expected to be an invertible
@@ -240,8 +245,8 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
       inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
   assert(viewSizesToLoopsMap && "expected invertible map");
   auto loopRanges =
-      makeTiledLoopRanges(scope.getBuilder(), scope.getLocation(),
-                          viewSizesToLoopsMap, viewSizes, tileSizes, folder);
+      makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
+                          viewSizes, tileSizes, folder);
 
   // 3. Create the tiled loops.
   LinalgOp res = op;
@@ -268,8 +273,9 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
 }
 
 llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
-                           OperationFolder &folder) {
+mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
+                           ArrayRef<int64_t> tileSizes,
+                           OperationFolder *folder) {
   if (tileSizes.empty())
     return llvm::None;
 
@@ -284,8 +290,9 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
     return llvm::None;
 
   // Create a builder for tile size constants.
-  OpBuilder builder(op);
-  ScopedContext scope(builder, op.getLoc());
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  ScopedContext scope(b, op.getLoc());
 
   // Materialize concrete tile size values to pass the generic tiling function.
   SmallVector<Value *, 8> tileSizeValues;
@@ -298,13 +305,14 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
       tileSizeValues.push_back(constant_index(folder, 0));
   }
 
-  return tileLinalgOp(op, tileSizeValues, folder);
+  return tileLinalgOp(b, op, tileSizeValues, folder);
 }
 
 static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
+  OpBuilder b(f);
   OperationFolder folder(f.getContext());
-  f.walk([tileSizes, &folder](LinalgOp op) {
-    auto opLoopsPair = tileLinalgOp(op, tileSizes, folder);
+  f.walk([tileSizes, &b, &folder](LinalgOp op) {
+    auto opLoopsPair = tileLinalgOp(b, op, tileSizes, &folder);
     // If tiling occurred successfully, erase old op.
     if (opLoopsPair)
       op.erase();
index 7fefe5c..dcd2e56 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -109,18 +109,18 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
 static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<Value *> operandsRef,
-                                            OperationFolder &state) {
+                                            OperationFolder *folder) {
   SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
   canonicalizeMapAndOperands(&map, &operands);
-  return state.create<AffineApplyOp>(b, loc, map, operands);
+  return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
+                : b.create<AffineApplyOp>(loc, map, operands);
 }
 
-SmallVector<Value *, 4> mlir::linalg::applyMapToValues(OpBuilder &b,
-                                                       Location loc,
-                                                       AffineMap map,
-                                                       ArrayRef<Value *> values,
-                                                       OperationFolder &state) {
+SmallVector<Value *, 4>
+mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map,
+                               ArrayRef<Value *> values,
+                               OperationFolder *folder) {
   SmallVector<Value *, 4> res;
   res.reserve(map.getNumResults());
   unsigned numDims = map.getNumDims();
@@ -129,7 +129,7 @@ SmallVector<Value *, 4> mlir::linalg::applyMapToValues(OpBuilder &b,
   // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
   for (auto expr : map.getResults()) {
     AffineMap map = AffineMap::get(numDims, 0, expr);
-    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state));
+    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, folder));
   }
   return res;
 }