Adds canonicalizer to SubViewOp which folds constants from base memref and operands...
authorAndy Davis <andydavis@google.com>
Thu, 14 Nov 2019 20:22:28 +0000 (12:22 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 20:23:04 +0000 (12:23 -0800)
Changes SubViewOp to support zero operands case, when offset, strides and sizes are all constant.

PiperOrigin-RevId: 280485075

mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/Transforms/canonicalize.mlir

index bfd2452..281707e 100644 (file)
@@ -1222,12 +1222,17 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
 
     The SubView operation supports the following arguments:
     *) Memref: the "base" memref on which to create a "view" memref.
-    *) Offsets: memref-rank number of dynamic offsets into the "base" memref at
-                which to create the "view" memref.
-    *) Sizes: memref-rank dynamic size operands which specify the dynamic sizes
-              of the result "view" memref type.
-    *) Strides: memref-rank number of dynamic strides which are applied
+    *) Offsets: zero or memref-rank number of dynamic offsets into the "base"
+                memref at which to create the "view" memref.
+    *) Sizes: zero or memref-rank dynamic size operands which specify the
+              dynamic sizes of the result "view" memref type.
+    *) Strides: zero or memref-rank number of dynamic strides which are applied
                 multiplicatively to the base memref strides in each dimension.
+    Note on the number of operands for offsets, sizes and strides: either
+    memref-rank number of operands must be set for each of offsets, sizes and
+    strides, or zero operands must be specified for offsets, sizes and strides
+    (in which case the base and subview memrefs must all have constant offset
+     sizes and strides).
 
     Example 1:
 
@@ -1254,6 +1259,15 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
         : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
           memref<?x?x?xf32,
             (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+    Example 3:
+
+      %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)>
+
+      // Subview with constant offsets, sizes and strides.
+      %1 = subview %0[][][]
+        : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+          memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
     }
   }];
 
@@ -1265,34 +1279,31 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
     "Builder *b, OperationState &result, Value *source, "
     "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
     "ArrayRef<Value *> strides, Type resultType = Type(), "
-    "ArrayRef<NamedAttribute> attrs = {}">];
+    "ArrayRef<NamedAttribute> attrs = {}">,
+    OpBuilder<
+    "Builder *builder, OperationState &result, Type resultType, Value *source",
+    [{
+      result.addOperands(source);
+      result.addTypes(resultType);
+    }]>];
 
   let extraClassDeclaration = [{
+    /// Returns the type of the base memref operand.
+    MemRefType getBaseMemRefType() {
+      return source()->getType().cast<MemRefType>();
+    }
+
     /// The result of a subview is always a memref.
     MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
 
     /// Returns the dynamic offsets for this subview operation.
-    operand_range getDynamicOffsets() {
-      return {operand_begin() + 1, operand_begin() + 1 + getType().getRank()};
-    }
-
-    /// Returns the operand starting position of the size operands.
-    unsigned getSizeOperandsStart() { return 1 + getType().getRank(); }
+    operand_range getDynamicOffsets();
 
     /// Returns the dynamic sizes for this subview operation if specified.
-    operand_range getDynamicSizes() {
-      return {operand_begin() + getSizeOperandsStart(),
-              operand_begin() + getSizeOperandsStart() + getType().getRank()};
-    }
-
-    /// Returns the operand starting position of the size operands.
-    unsigned getStrideOperandsStart() { return 1 + 2 * getType().getRank(); }
+    operand_range getDynamicSizes();
 
     /// Returns the dynamic strides for this subview operation if specified.
-    operand_range getDynamicStrides() {
-      return {operand_begin() + getStrideOperandsStart(),
-              operand_begin() + getStrideOperandsStart() + getType().getRank()};
-    }
+    operand_range getDynamicStrides();
 
     // Auxiliary range data structure and helper function that unpacks the
     // offset, size and stride operands of the SubViewOp into a list of triples.
@@ -1303,7 +1314,7 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
     SmallVector<Range, 8> getRanges();
   }];
 
-  // TODO(andydavis) Add canonicalizer.
+  let hasCanonicalizer = 1;
 }
 
 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
index e70675e..bf0cb75 100644 (file)
@@ -2533,11 +2533,11 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
            dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
 
     // Create new ViewOp.
-    auto newShapeCastOp = rewriter.create<ViewOp>(
-        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
+    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
+                                             viewOp.getOperand(0), newOperands);
     // Insert a cast so we have the same type as the old memref type.
     rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
-                                              newShapeCastOp, viewOp.getType());
+                                              newViewOp, viewOp.getType());
     return matchSuccess();
   }
 };
@@ -2658,7 +2658,8 @@ static LogicalResult verify(SubViewOp op) {
            << subViewType;
 
   // Verify that the subview layout map has a dynamic offset.
-  if (subViewOffset != MemRefType::getDynamicStrideOrOffset())
+  if (op.getNumOperands() > 1 &&
+      subViewOffset != MemRefType::getDynamicStrideOrOffset())
     return op.emitError("subview memref layout map must specify a dynamic "
                         "offset for type ")
            << subViewType;
@@ -2688,6 +2689,162 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
   return res;
 }
 
+static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
+  if (memrefType.getNumDynamicDims() > 0)
+    return false;
+  // Get offset and strides.
+  int64_t offset;
+  llvm::SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return false;
+  // Return 'false' if any of offset or strides is dynamic.
+  if (offset == MemRefType::getDynamicStrideOrOffset() ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
+    return false;
+  return true;
+}
+
+namespace {
+
+struct SubViewOpShapeFolder : public OpRewritePattern<SubViewOp> {
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+                                     PatternRewriter &rewriter) const override {
+    // Get base memref type.
+    auto baseMemrefType = subViewOp.getBaseMemRefType();
+    if (baseMemrefType.getAffineMaps().size() != 1)
+      return matchFailure();
+    auto baseMap = baseMemrefType.getAffineMaps()[0];
+
+    // Get base memref offsets and strides.
+    int64_t baseOffset;
+    llvm::SmallVector<int64_t, 4> baseStrides;
+    if (failed(getStridesAndOffset(baseMemrefType, baseStrides, baseOffset)))
+      return matchFailure();
+
+    // Keep it simple for now: return if any of the base memrefs offset, sizes
+    // or strides is dynamic.
+    if (baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+        baseMemrefType.getNumDynamicDims() > 0 ||
+        llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()))
+      return matchFailure();
+
+    // Get subView memref type.
+    auto subViewMemrefType = subViewOp.getType();
+    if (subViewMemrefType.getAffineMaps().size() != 1)
+      return matchFailure();
+    auto subViewMap = subViewMemrefType.getAffineMaps()[0];
+
+    // Return if the subViewOp has already been constant folded.
+    if (subViewOp.getNumOperands() == 1) {
+      assert(hasConstantOffsetSizesAndStrides(subViewMemrefType));
+      return matchFailure();
+    }
+
+    // Keep it simple for now: return if any view memref operands are dynamic.
+    SmallVector<Value *, 4> operands(subViewOp.getOperands().begin(),
+                                     subViewOp.getOperands().end());
+    ArrayRef<Value *> operandsRef(operands);
+    if (llvm::any_of(operandsRef.drop_front(), [](Value *operand) {
+          return !matchPattern(operand, m_ConstantIndex());
+        }))
+      return matchFailure();
+
+    // Compute new subview offset based on base memref strides.
+    int64_t newSubViewOffset = baseOffset;
+    SmallVector<Value *, 4> offsets(subViewOp.getDynamicOffsets().begin(),
+                                    subViewOp.getDynamicOffsets().end());
+    assert(offsets.size() == baseStrides.size());
+    for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+      auto constantOffsetOp =
+          cast<ConstantIndexOp>(offsets[i]->getDefiningOp());
+      newSubViewOffset += constantOffsetOp.getValue() * baseStrides[i];
+    }
+
+    // Fold any dynamic dim operands which are produced by a constant.
+    SmallVector<int64_t, 4> newShapeConstants;
+    newShapeConstants.reserve(subViewMemrefType.getRank());
+
+    unsigned dynamicDimPos = 1 + subViewMemrefType.getRank();
+    unsigned rank = subViewMemrefType.getRank();
+    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+      int64_t dimSize = subViewMemrefType.getDimSize(dim);
+      // SubViewOp shape folding currently folds everything or nothing, so we
+      // expect all dynamic sizes at this point.
+      assert(ShapedType::isDynamic(dimSize));
+      (void)dimSize;
+
+      auto *defOp = subViewOp.getOperand(dynamicDimPos)->getDefiningOp();
+      assert(defOp != nullptr);
+      assert(isa<ConstantIndexOp>(defOp));
+      auto constantSizeOp = cast<ConstantIndexOp>(defOp);
+      // Dynamic shape dimension will be folded.
+      newShapeConstants.push_back(constantSizeOp.getValue());
+      dynamicDimPos++;
+    }
+
+    // Compute new strides based on 'newShapeConstants'.
+    SmallVector<int64_t, 4> newSubViewStrides(rank);
+    newSubViewStrides[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; --i) {
+      assert(!ShapedType::isDynamic(newShapeConstants[i + 1]));
+      newSubViewStrides[i] =
+          newShapeConstants[i + 1] * newSubViewStrides[i + 1];
+    }
+
+    // Regenerate strided layout map with 'newSubViewStrides' and
+    // 'newSubViewOffset'.
+    subViewMap = makeStridedLinearLayoutMap(newSubViewStrides, newSubViewOffset,
+                                            rewriter.getContext());
+
+    // Create new memref type with constant folded dims and/or offset/strides.
+    auto newMemRefType =
+        MemRefType::get(newShapeConstants, subViewMemrefType.getElementType(),
+                        {subViewMap}, subViewMemrefType.getMemorySpace());
+
+    // Create new SubViewOp.
+    auto newSubViewOp = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), newMemRefType, subViewOp.getOperand(0));
+    // Insert a cast so we have the same type as the old memref type.
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+        operandsRef.drop_front(), subViewOp, newSubViewOp, subViewOp.getType());
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  return {operand_begin() + 1, operand_begin() + 1 + getType().getRank()};
+}
+
+SubViewOp::operand_range SubViewOp::getDynamicSizes() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  unsigned sizesOperandsStart = 1 + getType().getRank();
+  return {operand_begin() + sizesOperandsStart,
+          operand_begin() + sizesOperandsStart + getType().getRank()};
+}
+
+SubViewOp::operand_range SubViewOp::getDynamicStrides() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  unsigned stridesOperandsStart = 1 + 2 * getType().getRank();
+  return {operand_begin() + stridesOperandsStart,
+          operand_begin() + stridesOperandsStart + getType().getRank()};
+}
+
+void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<SubViewOpShapeFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
index 417c872..fd2d442 100644 (file)
@@ -20,6 +20,7 @@
 
 // CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = (d0, d1) -> (d0 * 22 + d1)
 // CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
+// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)
 
 // CHECK-LABEL: func @func_with_ops(%arg0: f32) {
 func @func_with_ops(f32) {
@@ -517,8 +518,6 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
-  //%2 = alloc() : memref<64xf32, (d0) -> (d0)>
-
   %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
   // CHECK: std.subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP0]]>
   %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
@@ -537,6 +536,12 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0]
     : memref<64x22xf32, (d0, d1) -> (d0 * 22 + d1)> to
       memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> 
+
+  // CHECK: std.subview %0[][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
+  %6 = subview %0[][][]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
+
   return
 }
 
index bd2d7de..9bef050 100644 (file)
 // CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)
 // CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1 + 15)
 
+// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)
+// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)
+// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)
+
 // CHECK-LABEL: func @test_subi_zero
 func @test_subi_zero(%arg0: i32) -> i32 {
   // CHECK-NEXT: %c0_i32 = constant 0 : i32
@@ -673,3 +677,39 @@ func @view(%arg0 : index) {
 
   return
 }
+
+
+// CHECK-LABEL: func @subview
+func @subview(%arg0 : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c7 = constant 7 : index
+  %c11 = constant 11 : index
+  %c15 = constant 15 : index
+
+  %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+
+  // Test: subview with constant base memref and constant operands is folded.
+  // CHECK: std.subview %0[][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
+  %1 = subview %0[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  // Test: subview with one dynamic operand should not be folded.
+  // CHECK: std.subview %0[%c0, %arg0, %c0][%c7, %c11, %c15][%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP1]]> 
+  %2 = subview %0[%c0, %arg0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  // Test: subview with constant operands but dynamic base memref is not folded.
+  %3 = alloc(%arg0) : memref<?x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+  // CHECK: std.subview %3[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1] : memref<?x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP1]]> 
+  %4 = subview %3[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<?x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  return
+}