[mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern
authorNicolas Vasilache <ntv@google.com>
Wed, 13 May 2020 02:21:36 +0000 (22:21 -0400)
committerNicolas Vasilache <ntv@google.com>
Wed, 13 May 2020 14:11:30 +0000 (10:11 -0400)
The existing implementation of SubViewOp::getRanges relies on all
offsets/sizes/strides to be dynamic values and does not work in
combination with canonicalization. This revision adds a
SubViewOp::getOrCreateRanges to create the missing constants in the
canonicalized case.

This allows reactivating the fused pass with staged pattern
applications.

However another issue surfaces that the SubViewOp verifier is now too
strict to allow folding. The existing folding pattern is turned into a
canonicalization pattern which rewrites memref_cast + subview into
subview + memref_cast.

The transform-patterns-matmul-to-vector can then be reactivated.

Differential Revision: https://reviews.llvm.org/D79759

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Transforms/canonicalize.mlir

index 30b5d43..e978323 100644 (file)
@@ -2676,8 +2676,18 @@ def SubViewOp : Std_Op<"subview", [
     struct Range {
       Value offset, size, stride;
     };
-    // TODO: retire `getRanges`.
-    SmallVector<Range, 8> getRanges();
+    /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
+    /// Range entry contains either the dynamic value or a ConstantIndexOp
+    /// constructed with `b` at location `loc`.
+    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+    /// A subview result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides);
 
     /// Return the rank of the result MemRefType.
     unsigned getRank() { return getType().getRank(); }
@@ -2750,7 +2760,6 @@ def SubViewOp : Std_Op<"subview", [
   }];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index d541ed2..34fe059 100644 (file)
@@ -184,15 +184,16 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
 
+  OpBuilder b(consumer.getOperation());
+  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+    loopRanges[posInProducerLoop] =
+        subView.getOrCreateRanges(b, loc)[en.index()];
   }
 
-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the view that
   // defines the loop ranges using the `producer`.
index 03f8d9e..5cbaa2f 100644 (file)
@@ -153,7 +153,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
   SmallVector<Value, 8> fullSizes, partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
-  for (auto en : llvm::enumerate(subView.getRanges())) {
+  for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
     auto rank = en.index();
     auto rangeValue = en.value();
     // Try to extract a tight constant.
@@ -169,7 +169,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
                             dynamicBuffers, folder, alignment);
   auto fullLocalView = folded_std_view(
       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
-      folded_std_constant_index(folder, 0), fullSizes);
+      zero, fullSizes);
   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
   SmallVector<Value, 4> ones(fullSizes.size(), one);
   auto partialLocalView =
index 7ca5e79..9cd97c3 100644 (file)
@@ -2275,10 +2275,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType,
-                                   ArrayRef<int64_t> staticOffsets,
-                                   ArrayRef<int64_t> staticSizes,
-                                   ArrayRef<int64_t> staticStrides) {
+Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank &&
@@ -2474,7 +2474,7 @@ static LogicalResult verify(SubViewOp op) {
     return failure();
 
   // Verify result type against inferred type.
-  auto expectedType = inferSubViewResultType(
+  auto expectedType = SubViewOp::inferSubViewResultType(
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
@@ -2489,16 +2489,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
             << range.stride;
 }
 
-SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
-  SmallVector<Range, 8> res;
-  unsigned rank = getType().getRank();
-  res.reserve(rank);
-  for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
-                           *(strides().begin() + i)});
-  return res;
-}
-
 static unsigned getNumDynamicEntriesUpToIdx(
     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
   return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
@@ -2540,6 +2530,29 @@ unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
   return 1 + offsets().size() + sizes().size() + numDynamic;
 }
 
+/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
+/// entry contains either the dynamic value or a ConstantIndexOp constructed
+/// with `b` at location `loc`.
+SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
+                                                              Location loc) {
+  SmallVector<Range, 8> res;
+  unsigned rank = getType().getRank();
+  res.reserve(rank);
+  for (unsigned idx = 0; idx < rank; ++idx) {
+    auto offset = isDynamicOffset(idx)
+                      ? getDynamicOffset(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
+    auto size = isDynamicSize(idx)
+                    ? getDynamicSize(idx)
+                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
+    auto stride = isDynamicStride(idx)
+                      ? getDynamicStride(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
+    res.emplace_back(Range{offset, size, stride});
+  }
+  return res;
+}
+
 LogicalResult
 SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
   if (!strides().empty())
@@ -2583,7 +2596,8 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+class SubViewOpConstantArgumentFolder final
+    : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;
 
@@ -2718,27 +2732,63 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }
 
-OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
-  auto folds = [](Operation *op) {
-    bool folded = false;
-    for (OpOperand &operand : op->getOpOperands()) {
-      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
-      if (castOp && canFoldIntoConsumerOp(castOp)) {
-        operand.set(castOp.getOperand());
-        folded = true;
-      }
-    }
-    return folded ? success() : failure();
-  };
+/// Pattern to rewrite a subview op with MemRefCast arguments.
+/// This essentially pushes memref_cast past its consuming subview when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+///   %1 = subview %0[0, 0][3, 4][1, 1] :
+///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
+///     memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
+public:
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
 
-  if (succeeded(folds(*this)))
-    return getResult();
-  return {};
-}
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const override {
+    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
+    /// the cast source operand type and the SubViewOp static information. This
+    /// is the resulting type if the MemRefCastOp were folded.
+    Type resultType = SubViewOp::inferSubViewResultType(
+        castOp.source().getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subViewOp.static_offsets()),
+        extractFromI64ArrayAttr(subViewOp.static_sizes()),
+        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    Value newSubView = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
+        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
+        subViewOp.static_sizes(), subViewOp.static_strides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubView);
+    return success();
+  }
+};
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpFolder>(context);
+  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
index 29ea43a..73c72ba 100644 (file)
@@ -1,7 +1,5 @@
-// TODO: this needs a fix to land before being reactivated.
-// RUN: ls
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                   %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
index f97cf21..76bd6b4 100644 (file)
@@ -941,3 +941,19 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<
   return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
 }
 
+// -----
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+// CHECK-LABEL: func @memref_cast_folding_subview_static(
+func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index)
+  -> memref<3x4xf32, offset:?, strides:[?, 1]>
+{
+  %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+  %1 = subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+
+  // CHECK:  subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+  // CHECK:  memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
+  return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
+}