From a4669cd3b49f1e7b7b06bcc5602ab00da5b99efb Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Thu, 14 Nov 2019 12:22:28 -0800
Subject: [PATCH] Adds canonicalizer to SubViewOp which folds constants from
 base memref and operands into the subview result memref type.

Changes SubViewOp to support the zero-operand case, where offsets, sizes and
strides are all constant.

PiperOrigin-RevId: 280485075
---
 mlir/include/mlir/Dialect/StandardOps/Ops.td |  59 ++++++----
 mlir/lib/Dialect/StandardOps/Ops.cpp         | 165 ++++++++++++++++++++++++++-
 mlir/test/IR/core-ops.mlir                   |   9 +-
 mlir/test/Transforms/canonicalize.mlir       |  40 +++++++
 4 files changed, 243 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index bfd2452..281707e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1222,12 +1222,17 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
     The SubView operation supports the following arguments:
     *) Memref: the "base" memref on which to create a "view" memref.
-    *) Offsets: memref-rank number of dynamic offsets into the "base" memref at
-       which to create the "view" memref.
-    *) Sizes: memref-rank dynamic size operands which specify the dynamic sizes
-       of the result "view" memref type.
-    *) Strides: memref-rank number of dynamic strides which are applied
+    *) Offsets: zero or memref-rank number of dynamic offsets into the "base"
+       memref at which to create the "view" memref.
+    *) Sizes: zero or memref-rank dynamic size operands which specify the
+       dynamic sizes of the result "view" memref type.
+    *) Strides: zero or memref-rank number of dynamic strides which are applied
       multiplicatively to the base memref strides in each dimension.
+
+    Note on the number of operands for offsets, sizes and strides: either
+    memref-rank operands must be supplied for each of offsets, sizes and
+    strides, or all three must have zero operands (in which case the base
+    and subview memrefs must both have constant offsets, sizes and strides).

     Example 1:

@@ -1254,6 +1259,15 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
       : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
        memref<?x?x?xf32,
          (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+    Example 3:
+
+    %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+
+    // Subview with constant offsets, sizes and strides.
+    %1 = subview %0[][][]
+      : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+        memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
   }
   }];
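Before diving into the implementation, here is a sketch of the rewrite the new
canonicalization pattern performs. The shapes, maps and constants are taken
from the canonicalize.mlir test added at the end of this patch; the value
names %base, %v and %f are illustrative only:

    // Before: offsets, sizes and strides are all defined by constant ops.
    %v = subview %base[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
      : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
        memref<?x?x?xf32,
          (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>

    // After: the constants are folded into the result memref type, and a
    // memref_cast reconciles existing uses with the old result type.
    %f = subview %base[][][]
      : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
        memref<7x11x15xf32, (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)>
    %v = memref_cast %f
      : memref<7x11x15xf32, (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)> to
        memref<?x?x?xf32,
          (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>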
@@ -1265,34 +1279,31 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
       "Builder *b, OperationState &result, Value *source, "
       "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
       "ArrayRef<Value *> strides, Type resultType = Type(), "
-      "ArrayRef<NamedAttribute> attrs = {}">];
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    OpBuilder<
+      "Builder *builder, OperationState &result, Type resultType, Value *source",
+      [{
+        result.addOperands(source);
+        result.addTypes(resultType);
+      }]>];

   let extraClassDeclaration = [{
+    /// Returns the type of the base memref operand.
+    MemRefType getBaseMemRefType() {
+      return source()->getType().cast<MemRefType>();
+    }
+
     /// The result of a subview is always a memref.
     MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }

     /// Returns the dynamic offsets for this subview operation.
-    operand_range getDynamicOffsets() {
-      return {operand_begin() + 1, operand_begin() + 1 + getType().getRank()};
-    }
-
-    /// Returns the operand starting position of the size operands.
-    unsigned getSizeOperandsStart() { return 1 + getType().getRank(); }
+    operand_range getDynamicOffsets();

     /// Returns the dynamic sizes for this subview operation if specified.
-    operand_range getDynamicSizes() {
-      return {operand_begin() + getSizeOperandsStart(),
-              operand_begin() + getSizeOperandsStart() + getType().getRank()};
-    }
-
-    /// Returns the operand starting position of the size operands.
-    unsigned getStrideOperandsStart() { return 1 + 2 * getType().getRank(); }
+    operand_range getDynamicSizes();

     /// Returns the dynamic strides for this subview operation if specified.
-    operand_range getDynamicStrides() {
-      return {operand_begin() + getStrideOperandsStart(),
-              operand_begin() + getStrideOperandsStart() + getType().getRank()};
-    }
+    operand_range getDynamicStrides();

     // Auxiliary range data structure and helper function that unpacks the
     // offset, size and stride operands of the SubViewOp into a list of triples.
@@ -1303,7 +1314,7 @@ def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> {
     SmallVector<Range, 8> getRanges();
   }];

-  // TODO(andydavis) Add canonicalizer.
+  let hasCanonicalizer = 1;
 }

 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index e70675e..bf0cb75 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2533,11 +2533,11 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
                        dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());

     // Create new ViewOp.
-    auto newShapeCastOp = rewriter.create<ViewOp>(
-        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
+    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
+                                             viewOp.getOperand(0), newOperands);
     // Insert a cast so we have the same type as the old memref type.
     rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
-                                              newShapeCastOp, viewOp.getType());
+                                              newViewOp, viewOp.getType());
     return matchSuccess();
   }
 };
@@ -2658,7 +2658,8 @@ static LogicalResult verify(SubViewOp op) {
            << subViewType;

   // Verify that the subview layout map has a dynamic offset.
-  if (subViewOffset != MemRefType::getDynamicStrideOrOffset())
+  if (op.getNumOperands() > 1 &&
+      subViewOffset != MemRefType::getDynamicStrideOrOffset())
     return op.emitError("subview memref layout map must specify a dynamic "
                         "offset for type ")
            << subViewType;
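Restated in IR terms, the relaxed verifier now accepts exactly the two operand
configurations documented above. A sketch (value names illustrative, types
mirroring Examples 2 and 3):

    // With memref-rank operands per list: the result layout map must still
    // carry a dynamic offset.
    %a = subview %0[%i0, %i1, %i2][%sz0, %sz1, %sz2][%st0, %st1, %st2]
      : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
        memref<?x?x?xf32,
          (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>

    // With zero operands: a fully static layout map is now legal, because the
    // dynamic-offset check is skipped when only the source operand is present.
    %b = subview %0[][][]
      : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
        memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>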
@@ -2688,6 +2689,162 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
   return res;
 }

+static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
+  if (memrefType.getNumDynamicDims() > 0)
+    return false;
+  // Get offset and strides.
+  int64_t offset;
+  llvm::SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return false;
+  // Return 'false' if any of offset or strides is dynamic.
+  if (offset == MemRefType::getDynamicStrideOrOffset() ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
+    return false;
+  return true;
+}
+
+namespace {
+
+struct SubViewOpShapeFolder : public OpRewritePattern<SubViewOp> {
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+                                     PatternRewriter &rewriter) const override {
+    // Get base memref type.
+    auto baseMemrefType = subViewOp.getBaseMemRefType();
+    if (baseMemrefType.getAffineMaps().size() != 1)
+      return matchFailure();
+    auto baseMap = baseMemrefType.getAffineMaps()[0];
+
+    // Get base memref offset and strides.
+    int64_t baseOffset;
+    llvm::SmallVector<int64_t, 4> baseStrides;
+    if (failed(getStridesAndOffset(baseMemrefType, baseStrides, baseOffset)))
+      return matchFailure();
+
+    // Keep it simple for now: return if any of the base memref's offset,
+    // sizes or strides is dynamic.
+    if (baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+        baseMemrefType.getNumDynamicDims() > 0 ||
+        llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()))
+      return matchFailure();
+
+    // Get subView memref type.
+    auto subViewMemrefType = subViewOp.getType();
+    if (subViewMemrefType.getAffineMaps().size() != 1)
+      return matchFailure();
+    auto subViewMap = subViewMemrefType.getAffineMaps()[0];
+
+    // Return if the subViewOp has already been constant folded.
+    if (subViewOp.getNumOperands() == 1) {
+      assert(hasConstantOffsetSizesAndStrides(subViewMemrefType));
+      return matchFailure();
+    }
+
+    // Keep it simple for now: return if any of the subview's operands is not
+    // defined by a constant.
+    SmallVector<Value *, 8> operands(subViewOp.getOperands().begin(),
+                                     subViewOp.getOperands().end());
+    ArrayRef<Value *> operandsRef(operands);
+    if (llvm::any_of(operandsRef.drop_front(), [](Value *operand) {
+          return !matchPattern(operand, m_ConstantIndex());
+        }))
+      return matchFailure();
+
+    // Compute new subview offset based on base memref strides.
+    int64_t newSubViewOffset = baseOffset;
+    SmallVector<Value *, 4> offsets(subViewOp.getDynamicOffsets().begin(),
+                                    subViewOp.getDynamicOffsets().end());
+    assert(offsets.size() == baseStrides.size());
+    for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+      auto constantOffsetOp =
+          cast<ConstantIndexOp>(offsets[i]->getDefiningOp());
+      newSubViewOffset += constantOffsetOp.getValue() * baseStrides[i];
+    }
+
+    // Fold any dynamic dim operands which are produced by a constant.
+    SmallVector<int64_t, 4> newShapeConstants;
+    newShapeConstants.reserve(subViewMemrefType.getRank());
+
+    unsigned dynamicDimPos = 1 + subViewMemrefType.getRank();
+    unsigned rank = subViewMemrefType.getRank();
+    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+      int64_t dimSize = subViewMemrefType.getDimSize(dim);
+      // SubViewOp shape folding currently folds everything or nothing, so we
+      // expect all dynamic sizes at this point.
+      assert(ShapedType::isDynamic(dimSize));
+      (void)dimSize;
+
+      auto *defOp = subViewOp.getOperand(dynamicDimPos)->getDefiningOp();
+      assert(defOp != nullptr);
+      assert(isa<ConstantIndexOp>(defOp));
+      auto constantSizeOp = cast<ConstantIndexOp>(defOp);
+      // Dynamic shape dimension will be folded.
+      newShapeConstants.push_back(constantSizeOp.getValue());
+      dynamicDimPos++;
+    }
+
+    // Compute new strides based on 'newShapeConstants'.
+    SmallVector<int64_t, 4> newSubViewStrides(rank);
+    newSubViewStrides[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; --i) {
+      assert(!ShapedType::isDynamic(newShapeConstants[i + 1]));
+      newSubViewStrides[i] =
+          newShapeConstants[i + 1] * newSubViewStrides[i + 1];
+    }
+
+    // Regenerate strided layout map with 'newSubViewStrides' and
+    // 'newSubViewOffset'.
+    subViewMap = makeStridedLinearLayoutMap(newSubViewStrides,
+                                            newSubViewOffset,
+                                            rewriter.getContext());
+
+    // Create new memref type with constant folded dims and/or offset/strides.
+    auto newMemRefType =
+        MemRefType::get(newShapeConstants, subViewMemrefType.getElementType(),
+                        {subViewMap}, subViewMemrefType.getMemorySpace());
+
+    // Create new SubViewOp.
+    auto newSubViewOp = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), newMemRefType, subViewOp.getOperand(0));
+    // Insert a cast so we have the same type as the old memref type.
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+        operandsRef.drop_front(), subViewOp, newSubViewOp, subViewOp.getType());
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  return {operand_begin() + 1, operand_begin() + 1 + getType().getRank()};
+}
+
+SubViewOp::operand_range SubViewOp::getDynamicSizes() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  unsigned sizesOperandsStart = 1 + getType().getRank();
+  return {operand_begin() + sizesOperandsStart,
+          operand_begin() + sizesOperandsStart + getType().getRank()};
+}
+
+SubViewOp::operand_range SubViewOp::getDynamicStrides() {
+  if (hasConstantOffsetSizesAndStrides(getBaseMemRefType()) &&
+      hasConstantOffsetSizesAndStrides(getType()))
+    return {operand_end(), operand_end()};
+  unsigned stridesOperandsStart = 1 + 2 * getType().getRank();
+  return {operand_begin() + stridesOperandsStart,
+          operand_begin() + stridesOperandsStart + getType().getRank()};
+}
+
+void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<SubViewOpShapeFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
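As a worked instance of the arithmetic in SubViewOpShapeFolder, take the
constant-operand test added to canonicalize.mlir below: the base type is
memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, so baseOffset = 0
and baseStrides = [64, 4, 1], with operands offsets = [0, 0, 0],
sizes = [7, 11, 15] and strides = [1, 1, 1]:

    // newSubViewOffset = baseOffset + sum_i(offset[i] * baseStrides[i])
    //                  = 0 + 0*64 + 0*4 + 0*1 = 0
    // newShapeConstants = [7, 11, 15]
    // newSubViewStrides are suffix products of the folded shape:
    //   strides[2] = 1, strides[1] = 15, strides[0] = 11 * 15 = 165
    // Folded result type:
    //   memref<7x11x15xf32, (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)>

Note that the recomputed strides are derived from the folded shape alone;
this matches the CHECK lines in the new @subview test.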
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 417c872..fd2d442 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -20,6 +20,7 @@
 // CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = (d0, d1) -> (d0 * 22 + d1)
 // CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
+// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)

 // CHECK-LABEL: func @func_with_ops(%arg0: f32) {
 func @func_with_ops(f32) {
@@ -517,8 +518,6 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index

-  //%2 = alloc() : memref<64xf32, (d0) -> (d0)>
-
   %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
   // CHECK: std.subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP0]]>
   %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
@@ -537,6 +536,12 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0]
     : memref<64x22xf32, (d0, d1) -> (d0 * 22 + d1)> to
       memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+
+  // CHECK: std.subview %0[][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
+  %6 = subview %0[][][]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
+
   return
 }
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index bd2d7de..9bef050 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -10,6 +10,10 @@
 // CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)
 // CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1 + 15)
+// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)
+// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)
+// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)
+
 // CHECK-LABEL: func @test_subi_zero
 func @test_subi_zero(%arg0: i32) -> i32 {
   // CHECK-NEXT: %c0_i32 = constant 0 : i32
@@ -673,3 +677,39 @@ func @view(%arg0 : index) {
   return
 }
+
+
+// CHECK-LABEL: func @subview
+func @subview(%arg0 : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c7 = constant 7 : index
+  %c11 = constant 11 : index
+  %c15 = constant 15 : index
+
+  %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+
+  // Test: subview with constant base memref and constant operands is folded.
+  // CHECK: std.subview %0[][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
+  %1 = subview %0[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  // Test: subview with one dynamic operand should not be folded.
+  // CHECK: std.subview %0[%c0, %arg0, %c0][%c7, %c11, %c15][%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP1]]>
+  %2 = subview %0[%c0, %arg0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  // Test: subview with constant operands but dynamic base memref is not folded.
+  %3 = alloc(%arg0) : memref<?x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+  // CHECK: std.subview %3[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1] : memref<?x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP1]]>
+  %4 = subview %3[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+    : memref<?x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  return
+}
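The canonicalize.mlir hunk above does not show the file's RUN line. By MLIR
test convention, the new @subview cases are driven by something close to the
following (an assumption; the exact pass flags in the checked-in RUN line may
differ by MLIR revision):

    // RUN: mlir-opt %s -canonicalize | FileCheck %s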
--
2.7.4