From bd7de6d4dfb82df21f36c1c331cba87a4d0118f7 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 17 Sep 2019 11:49:14 -0700 Subject: [PATCH] Add rewrite pattern to compose maps into affine load/stores - add canonicalization pattern to compose maps into affine loads/stores; templatize the pattern and reuse it for affine.apply as well - rename getIndices -> getMapOperands() (getIndices is confusing since these are no longer the indices themselves but operands to the map whose results are the indices). This also makes the accessor uniform across affine.apply/load/store. Change arg names on the affine load/store builder to avoid confusion. Drop an unused confusing build method on AffineStoreOp. - update incomplete doc comment for canonicalizeMapAndOperands (this was missed from a previous update). Addresses issue tensorflow/mlir#121 Signed-off-by: Uday Bondhugula Closes tensorflow/mlir#122 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/122 from bondhugula:compose-load-store e71de1771e56a85c4282c10cb43f30cef0701c4f PiperOrigin-RevId: 269619540 --- mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 26 +++--- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 103 +++++++++++++++++------- mlir/lib/Transforms/LowerAffine.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 10 ++- mlir/test/AffineOps/canonicalize.mlir | 27 +++++++ mlir/test/Transforms/canonicalize.mlir | 5 +- 8 files changed, 130 insertions(+), 51 deletions(-) diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 60a1c68..896299e 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -83,6 +83,8 @@ public: static StringRef getOperationName() { return "affine.apply"; } + operand_range getMapOperands() { return getOperands(); } + // Hooks to customize behavior of this op. static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -400,9 +402,12 @@ public: /// Builds an affine load op with the specified map and operands. static void build(Builder *builder, OperationState *result, AffineMap map, ArrayRef operands); - /// Builds an affine load op an identify map and operands. + /// Builds an affine load op with an identity map and operands. static void build(Builder *builder, OperationState *result, Value *memref, ArrayRef indices = {}); + /// Builds an affine load op with the specified map and its operands. + static void build(Builder *builder, OperationState *result, Value *memref, + AffineMap map, ArrayRef mapOperands); /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 0; } @@ -415,7 +420,7 @@ public: } /// Get affine map operands. - operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); } + operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); } /// Returns the affine map used to index the memref for this operation. AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } @@ -462,14 +467,14 @@ class AffineStoreOp : public Op operands); - /// Builds an affine store operation with an identity map and operands. + /// Builds an affine store operation with the provided indices (identity map). static void build(Builder *builder, OperationState *result, Value *valueToStore, Value *memref, - ArrayRef operands); + ArrayRef indices); + /// Builds an affine store operation with the specified map and its operands. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, AffineMap map, + ArrayRef mapOperands); /// Get value to be stored by store operation. Value *getValueToStore() { return getOperand(0); } @@ -486,7 +491,7 @@ public: } /// Get affine map operands. - operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); } + operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); } /// Returns the affine map used to index the memref for this operation. AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } @@ -521,6 +526,9 @@ bool isValidSymbol(Value *value); /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands /// 2. drop unused dims and symbols from map +/// 3. promote valid symbols to symbolic operands in case they appeared as +/// dimensional operands +/// 4. propagate constant operands and drop them void canonicalizeMapAndOperands(AffineMap *map, llvm::SmallVectorImpl *operands); /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 45012b0..b1895d3 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -236,7 +236,7 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp, int uniqueVaryingIndexAlongIv = -1; auto accessMap = memoryOp.getAffineMap(); - SmallVector mapOperands(memoryOp.getIndices()); + SmallVector mapOperands(memoryOp.getMapOperands()); unsigned numDims = accessMap.getNumDims(); for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) { // Gather map operands used result expr 'i' in 'exprOperands'. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 660b77e..3f94aa8 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -847,7 +847,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); indices.reserve(loadMemrefType.getRank()); - for (auto *index : loadOp.getIndices()) { + for (auto *index : loadOp.getMapOperands()) { indices.push_back(index); } } else { @@ -857,7 +857,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); indices.reserve(storeMemrefType.getRank()); - for (auto *index : storeOp.getIndices()) { + for (auto *index : storeOp.getMapOperands()) { indices.push_back(index); } } diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 2c98806..77df364 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -698,30 +698,63 @@ void mlir::canonicalizeSetAndOperands( } namespace { -/// Simplify AffineApply operations. +/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing +/// maps that supply results into them. /// -struct SimplifyAffineApply : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct SimplifyAffineOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineApplyOp apply, - PatternRewriter &rewriter) const override { - auto map = apply.getAffineMap(); + void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, + AffineMap map, ArrayRef mapOperands) const; + PatternMatchResult matchAndRewrite(AffineOpTy affineOp, + PatternRewriter &rewriter) const override { + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value, + "affine load/store/apply op expected"); + auto map = affineOp.getAffineMap(); AffineMap oldMap = map; - SmallVector resultOperands(apply.getOperands()); + auto oldOperands = affineOp.getMapOperands(); + SmallVector resultOperands(oldOperands); composeAffineMapAndOperands(&map, &resultOperands); - if (map == oldMap) - return matchFailure(); + if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), + resultOperands.begin())) + return this->matchFailure(); - rewriter.replaceOpWithNewOp(apply, map, resultOperands); - return matchSuccess(); + replaceAffineOp(rewriter, affineOp, map, resultOperands); + return this->matchSuccess(); } }; + +// Specialize the template to account for the different build signatures for +// affine load, store, and apply ops. +template <> +void SimplifyAffineOp::replaceAffineOp( + PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, + ArrayRef mapOperands) const { + rewriter.replaceOpWithNewOp(load, load.getMemRef(), map, + mapOperands); +} +template <> +void SimplifyAffineOp::replaceAffineOp( + PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, + ArrayRef mapOperands) const { + rewriter.replaceOpWithNewOp( + store, store.getValueToStore(), store.getMemRef(), map, mapOperands); +} +template <> +void SimplifyAffineOp::replaceAffineOp( + PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map, + ArrayRef mapOperands) const { + rewriter.replaceOpWithNewOp(apply, map, mapOperands); +} } // end anonymous namespace. void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -1689,6 +1722,7 @@ void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AffineLoadOp::build(Builder *builder, OperationState *result, AffineMap map, ArrayRef operands) { + assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); result->addOperands(operands); if (map) result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); @@ -1697,17 +1731,25 @@ void AffineLoadOp::build(Builder *builder, OperationState *result, } void AffineLoadOp::build(Builder *builder, OperationState *result, - Value *memref, ArrayRef indices) { + Value *memref, AffineMap map, + ArrayRef mapOperands) { + assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result->addOperands(memref); - result->addOperands(indices); + result->addOperands(mapOperands); + auto memrefType = memref->getType().cast(); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result->types.push_back(memrefType.getElementType()); +} + +void AffineLoadOp::build(Builder *builder, OperationState *result, + Value *memref, ArrayRef indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder->getMultiDimIdentityMap(rank) : builder->getEmptyAffineMap(); - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); - result->types.push_back(memrefType.getElementType()); + build(builder, result, memref, map, indices); } ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { @@ -1733,7 +1775,7 @@ void AffineLoadOp::print(OpAsmPrinter *p) { *p << "affine.load " << *getMemRef() << '['; AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); if (mapAttr) { - SmallVector operands(getIndices()); + SmallVector operands(getMapOperands()); p->printAffineMapOfSSAIds(mapAttr, operands); } *p << ']'; @@ -1759,7 +1801,7 @@ LogicalResult AffineLoadOp::verify() { "expects the number of subscripts to be equal to memref rank"); } - for (auto *idx : getIndices()) { + for (auto *idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -1772,6 +1814,7 @@ void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load results.insert(getOperationName(), context); + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -1779,27 +1822,26 @@ void AffineLoadOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void AffineStoreOp::build(Builder *builder, OperationState *result, - Value *valueToStore, AffineMap map, - ArrayRef operands) { + Value *valueToStore, Value *memref, AffineMap map, + ArrayRef mapOperands) { + assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result->addOperands(valueToStore); - result->addOperands(operands); - if (map) - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result->addOperands(memref); + result->addOperands(mapOperands); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); } +// Use identity map. void AffineStoreOp::build(Builder *builder, OperationState *result, Value *valueToStore, Value *memref, - ArrayRef operands) { - result->addOperands(valueToStore); - result->addOperands(memref); - result->addOperands(operands); + ArrayRef indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder->getMultiDimIdentityMap(rank) : builder->getEmptyAffineMap(); - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + build(builder, result, valueToStore, memref, map, indices); } ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { @@ -1828,7 +1870,7 @@ void AffineStoreOp::print(OpAsmPrinter *p) { *p << ", " << *getMemRef() << '['; AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); if (mapAttr) { - SmallVector operands(getIndices()); + SmallVector operands(getMapOperands()); p->printAffineMapOfSSAIds(mapAttr, operands); } *p << ']'; @@ -1855,7 +1897,7 @@ LogicalResult AffineStoreOp::verify() { "expects the number of subscripts to be equal to memref rank"); } - for (auto *idx : getIndices()) { + for (auto *idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to store must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -1868,6 +1910,7 @@ void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load results.insert(getOperationName(), context); + results.insert>(context); } #define GET_OP_CLASSES diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 2ed01a7..4111ba0 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -403,7 +403,7 @@ public: virtual PatternMatchResult matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. - SmallVector indices(op.getIndices()); + SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) @@ -425,7 +425,7 @@ public: virtual PatternMatchResult matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. - SmallVector indices(op.getIndices()); + SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 606cdb7..a54b05e 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -814,14 +814,15 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, // as needed by various targets. if (auto load = dyn_cast(opInst)) { OpBuilder b(opInst); - SmallVector mapOperands(load.getIndices()); + SmallVector mapOperands(load.getMapOperands()); SmallVector indices; indices.reserve(load.getMemRefType().getRank()); if (load.getAffineMap() != b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices); } else { - indices.append(load.getIndices().begin(), load.getIndices().end()); + indices.append(load.getMapOperands().begin(), + load.getMapOperands().end()); } auto permutationMap = makePermutationMap(opInst, indices, state->strategy->loopToVectorDim); @@ -1038,7 +1039,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); - SmallVector mapOperands(store.getIndices()); + SmallVector mapOperands(store.getMapOperands()); SmallVector indices; indices.reserve(store.getMemRefType().getRank()); if (store.getAffineMap() != @@ -1046,7 +1047,8 @@ static Operation *vectorizeOneOperation(Operation *opInst, computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands, indices); } else { - indices.append(store.getIndices().begin(), store.getIndices().end()); + indices.append(store.getMapOperands().begin(), + store.getMapOperands().end()); } auto permutationMap = diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir index a5768ad..6d84913 100644 --- a/mlir/test/AffineOps/canonicalize.mlir +++ b/mlir/test/AffineOps/canonicalize.mlir @@ -424,6 +424,7 @@ func @fold_empty_loop() { } return } +// CHECK: return // ----- @@ -476,3 +477,29 @@ func @canonicalize_bounds(%M : index, %N : index) { } return } + +// ----- + +// Compose maps into affine load and store ops. + +// CHECK-DAG: #map{{[0-9]+}} = (d0) -> (d0 + 1) + +// CHECK-LABEL: @compose_into_affine_load_store +func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) { + %cf1 = constant 1.0 : f32 + // CHECK: affine.for %[[IV:.*]] = 0 to 1024 + affine.for %i = 0 to 1024 { + // Make sure the unused operand (%u below) gets dropped as well. + %idx = affine.apply (d0, d1) -> (d0 + 1) (%i, %u) + affine.load %A[%idx] : memref<1024xf32> + affine.store %cf1, %A[%idx] : memref<1024xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[IV]] + 1] + // CHECK-NEXT: affine.store %cst, %{{.*}}[%[[IV]] + 1] + + // Map remains the same, but operand changes on composition. + %copy = affine.apply (d0) -> (d0) (%i) + affine.load %A[%copy] : memref<1024xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[IV]]] + } + return +} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index b954b69..cecd666 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -256,13 +256,12 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> { // CHECK-LABEL: func @memref_cast_folding func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 { - // CHECK-NOT: memref_cast %1 = memref_cast %arg0 : memref<4xf32> to memref + // CHECK-NEXT: %c0 = constant 0 : index %c0 = constant 0 : index - // CHECK-NOT: dim %dim = dim %1, 0 : memref - // CHECK: affine.load %arg0[%c4 - 1] + // CHECK-NEXT: affine.load %arg0[3] affine.load %1[%dim - 1] : memref // CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32> -- 2.7.4