Refactor various canonicalization patterns as in-place folds.
authorRiver Riddle <riverriddle@google.com>
Fri, 13 Dec 2019 22:52:39 +0000 (14:52 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 14 Dec 2019 01:19:02 +0000 (17:19 -0800)
This is more efficient, and allows for these to fire in more situations: e.g. createOrFold, DialectConversion, etc.

PiperOrigin-RevId: 285476837

mlir/include/mlir/Dialect/AffineOps/AffineOps.h
mlir/include/mlir/Dialect/AffineOps/AffineOps.td
mlir/include/mlir/Dialect/QuantOps/QuantOps.td
mlir/include/mlir/Dialect/StandardOps/Ops.h
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp

index 835ac24..8268f81 100644 (file)
@@ -295,8 +295,8 @@ public:
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
   LogicalResult verify();
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
+  LogicalResult fold(ArrayRef<Attribute> cstOperands,
+                     SmallVectorImpl<OpFoldResult> &results);
 
   /// Returns true if this DMA operation is strided, returns false otherwise.
   bool isStrided() {
@@ -380,8 +380,8 @@ public:
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
   LogicalResult verify();
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
+  LogicalResult fold(ArrayRef<Attribute> cstOperands,
+                     SmallVectorImpl<OpFoldResult> &results);
 };
 
 /// The "affine.load" op reads an element from a memref, where the index
@@ -450,6 +450,7 @@ public:
   LogicalResult verify();
   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context);
+  OpFoldResult fold(ArrayRef<Attribute> operands);
 };
 
 /// The "affine.store" op writes an element to a memref, where the index
@@ -520,6 +521,8 @@ public:
   LogicalResult verify();
   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context);
+  LogicalResult fold(ArrayRef<Attribute> cstOperands,
+                     SmallVectorImpl<OpFoldResult> &results);
 };
 
 /// Returns true if the given Value can be used as a dimension id.
index 4d40604..cea44b8 100644 (file)
@@ -177,12 +177,13 @@ def AffineForOp : Affine_Op<"for",
     /// Sets the upper bound to the given constant value.
     void setConstantUpperBound(int64_t value);
 
-    /// Returns true if both the lower and upper bound have the same operand 
+    /// Returns true if both the lower and upper bound have the same operand
     /// lists (same operands in the same order).
     bool matchingBoundOperandList();
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
@@ -239,7 +240,7 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
     }
   }];
 
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def AffineMinOp : Affine_Op<"min"> {
index 85d5cd2..072715d 100644 (file)
@@ -93,7 +93,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
 def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
   let arguments = (ins quant_RealOrStorageValueType:$arg);
   let results = (outs quant_RealOrStorageValueType);
-  let hasCanonicalizer = 0b1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index c7c8714..fcf16c0 100644 (file)
@@ -268,8 +268,8 @@ public:
   void print(OpAsmPrinter &p);
   LogicalResult verify();
 
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
+  LogicalResult fold(ArrayRef<Attribute> cstOperands,
+                     SmallVectorImpl<OpFoldResult> &results);
 
   bool isStrided() {
     return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
@@ -331,8 +331,8 @@ public:
 
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
+  LogicalResult fold(ArrayRef<Attribute> cstOperands,
+                     SmallVectorImpl<OpFoldResult> &results);
 };
 
 /// Prints dimension and symbol list.
index 8e21a8b..553a612 100644 (file)
@@ -659,6 +659,7 @@ def DeallocOp : Std_Op<"dealloc"> {
   let arguments = (ins AnyMemRef:$memref);
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def DimOp : Std_Op<"dim", [NoSideEffect]> {
@@ -691,7 +692,6 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
   }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 def DivFOp : FloatArithmeticOp<"divf"> {
@@ -834,7 +834,7 @@ def LoadOp : Std_Op<"load"> {
     operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
   }];
 
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def LogOp : FloatUnaryOp<"log"> {
@@ -1137,7 +1137,7 @@ def StoreOp : Std_Op<"store"> {
       }
   }];
 
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def SubFOp : FloatArithmeticOp<"subf"> {
index 22d4ec1..e58f6f8 100644 (file)
@@ -814,33 +814,20 @@ void AffineApplyOp::getCanonicalizationPatterns(
 // Common canonicalization pattern support logic
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
 /// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
-  /// The rootOpName is the name of the root operation to match against.
-  MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
-      : RewritePattern(rootOpName, 1, context) {}
-
-  PatternMatchResult match(Operation *op) const override {
-    for (auto *operand : op->getOperands())
-      if (matchPattern(operand, m_Op<MemRefCastOp>()))
-        return matchSuccess();
-
-    return matchFailure();
-  }
-
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
-      if (auto *memref = op->getOperand(i)->getDefiningOp())
-        if (auto cast = dyn_cast<MemRefCastOp>(memref))
-          op->setOperand(i, cast.getOperand());
-    rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+    if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+      operand.set(cast.getOperand());
+      folded = true;
+    }
   }
-};
-
-} // end anonymous namespace.
+  return success(folded);
+}
 
 //===----------------------------------------------------------------------===//
 // AffineDmaStartOp
@@ -1005,10 +992,10 @@ LogicalResult AffineDmaStartOp::verify() {
   return success();
 }
 
-void AffineDmaStartOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+                                     SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  return foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1081,10 +1068,10 @@ LogicalResult AffineDmaWaitOp::verify() {
   return success();
 }
 
-void AffineDmaWaitOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  return foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1255,7 +1242,8 @@ static ParseResult parseBound(bool isLower, OperationState &result,
       "expected valid affine map representation for loop bounds");
 }
 
-ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseAffineForOp(OpAsmParser &parser,
+                                    OperationState &result) {
   auto &builder = parser.getBuilder();
   OpAsmParser::OperandType inductionVariable;
   // Parse the induction variable followed by '='.
@@ -1344,7 +1332,7 @@ static void printBound(AffineMapAttr boundMap,
                         map.getNumDims(), p);
 }
 
-void print(OpAsmPrinter &p, AffineForOp op) {
+static void print(OpAsmPrinter &p, AffineForOp op) {
   p << "affine.for ";
   p.printOperand(op.getBody()->getArgument(0));
   p << " = ";
@@ -1363,115 +1351,102 @@ void print(OpAsmPrinter &p, AffineForOp op) {
                                            op.getStepAttrName()});
 }
 
-namespace {
-/// This is a pattern to fold trivially empty loops.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
-  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Fold the constant bounds of a loop.
+static LogicalResult foldLoopBounds(AffineForOp forOp) {
+  auto foldLowerOrUpperBound = [&forOp](bool lower) {
+    // Check to see if each of the operands is the result of a constant.  If
+    // so, get the value.  If not, ignore it.
+    SmallVector<Attribute, 8> operandConstants;
+    auto boundOperands =
+        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
+    for (auto *operand : boundOperands) {
+      Attribute operandCst;
+      matchPattern(operand, m_Constant(&operandCst));
+      operandConstants.push_back(operandCst);
+    }
 
-  PatternMatchResult matchAndRewrite(AffineForOp forOp,
-                                     PatternRewriter &rewriter) const override {
-    // Check that the body only contains a terminator.
-    if (!has_single_element(*forOp.getBody()))
-      return matchFailure();
-    rewriter.eraseOp(forOp);
-    return matchSuccess();
-  }
-};
+    AffineMap boundMap =
+        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
+    assert(boundMap.getNumResults() >= 1 &&
+           "bound maps should have at least one result");
+    SmallVector<Attribute, 4> foldedResults;
+    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
+      return failure();
 
-/// This is a pattern to fold constant loop bounds.
-struct AffineForOpBoundFolder : public OpRewritePattern<AffineForOp> {
-  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+    // Compute the max or min as applicable over the results.
+    assert(!foldedResults.empty() && "bounds should have at least one result");
+    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
+      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
+                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
+    }
+    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
+          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
+    return success();
+  };
 
-  PatternMatchResult matchAndRewrite(AffineForOp forOp,
-                                     PatternRewriter &rewriter) const override {
-    auto foldLowerOrUpperBound = [&forOp](bool lower) {
-      // Check to see if each of the operands is the result of a constant.  If
-      // so, get the value.  If not, ignore it.
-      SmallVector<Attribute, 8> operandConstants;
-      auto boundOperands =
-          lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
-      for (auto *operand : boundOperands) {
-        Attribute operandCst;
-        matchPattern(operand, m_Constant(&operandCst));
-        operandConstants.push_back(operandCst);
-      }
+  // Try to fold the lower bound.
+  bool folded = false;
+  if (!forOp.hasConstantLowerBound())
+    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
 
-      AffineMap boundMap =
-          lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
-      assert(boundMap.getNumResults() >= 1 &&
-             "bound maps should have at least one result");
-      SmallVector<Attribute, 4> foldedResults;
-      if (failed(boundMap.constantFold(operandConstants, foldedResults)))
-        return failure();
-
-      // Compute the max or min as applicable over the results.
-      assert(!foldedResults.empty() &&
-             "bounds should have at least one result");
-      auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
-      for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
-        auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
-        maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
-                         : llvm::APIntOps::smin(maxOrMin, foldedResult);
-      }
-      lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
-            : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
-      return success();
-    };
-
-    // Try to fold the lower bound.
-    bool folded = false;
-    if (!forOp.hasConstantLowerBound())
-      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
-
-    // Try to fold the upper bound.
-    if (!forOp.hasConstantUpperBound())
-      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
-
-    // If any of the bounds were folded we return success.
-    if (!folded)
-      return matchFailure();
-    rewriter.updatedRootInPlace(forOp);
-    return matchSuccess();
-  }
-};
+  // Try to fold the upper bound.
+  if (!forOp.hasConstantUpperBound())
+    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
+  return success(folded);
+}
 
-// This is a pattern to canonicalize affine for op loop bounds.
-struct AffineForOpBoundCanonicalizer : public OpRewritePattern<AffineForOp> {
-  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Canonicalize the bounds of the given loop.
+static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
+  SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+  SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
 
-  PatternMatchResult matchAndRewrite(AffineForOp forOp,
-                                     PatternRewriter &rewriter) const override {
-    SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
-    SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  auto prevLbMap = lbMap;
+  auto prevUbMap = ubMap;
 
-    auto lbMap = forOp.getLowerBoundMap();
-    auto ubMap = forOp.getUpperBoundMap();
-    auto prevLbMap = lbMap;
-    auto prevUbMap = ubMap;
+  canonicalizeMapAndOperands(&lbMap, &lbOperands);
+  canonicalizeMapAndOperands(&ubMap, &ubOperands);
 
-    canonicalizeMapAndOperands(&lbMap, &lbOperands);
-    canonicalizeMapAndOperands(&ubMap, &ubOperands);
+  // Any canonicalization change always leads to updated map(s).
+  if (lbMap == prevLbMap && ubMap == prevUbMap)
+    return failure();
 
-    // Any canonicalization change always leads to updated map(s).
-    if (lbMap == prevLbMap && ubMap == prevUbMap)
-      return matchFailure();
+  if (lbMap != prevLbMap)
+    forOp.setLowerBound(lbOperands, lbMap);
+  if (ubMap != prevUbMap)
+    forOp.setUpperBound(ubOperands, ubMap);
+  return success();
+}
 
-    if (lbMap != prevLbMap)
-      forOp.setLowerBound(lbOperands, lbMap);
-    if (ubMap != prevUbMap)
-      forOp.setUpperBound(ubOperands, ubMap);
+namespace {
+/// This is a pattern to fold trivially empty loops.
+struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
+  using OpRewritePattern<AffineForOp>::OpRewritePattern;
 
-    rewriter.updatedRootInPlace(forOp);
+  PatternMatchResult matchAndRewrite(AffineForOp forOp,
+                                     PatternRewriter &rewriter) const override {
+    // Check that the body only contains a terminator.
+    if (!has_single_element(*forOp.getBody()))
+      return matchFailure();
+    rewriter.eraseOp(forOp);
     return matchSuccess();
   }
 };
-
 } // end anonymous namespace
 
 void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<AffineForEmptyLoopFolder, AffineForOpBoundFolder,
-                 AffineForOpBoundCanonicalizer>(context);
+  results.insert<AffineForEmptyLoopFolder>(context);
+}
+
+LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+  bool folded = succeeded(foldLoopBounds(*this));
+  folded |= succeeded(canonicalizeLoopBounds(*this));
+  return success(folded);
 }
 
 AffineBound AffineForOp::getLowerBound() {
@@ -1741,37 +1716,23 @@ void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set,
     AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location);
 }
 
-namespace {
-// This is a pattern to canonicalize an affine if op's conditional (integer
-// set + operands).
-struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
-  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+/// Canonicalize an affine if op's conditional (integer set + operands).
+LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  auto set = getIntegerSet();
+  SmallVector<Value *, 4> operands(getOperands());
+  canonicalizeSetAndOperands(&set, &operands);
 
-  PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
-                                     PatternRewriter &rewriter) const override {
-    auto set = ifOp.getIntegerSet();
-    SmallVector<Value *, 4> operands(ifOp.getOperands());
-
-    canonicalizeSetAndOperands(&set, &operands);
-
-    // Any canonicalization change always leads to either a reduction in the
-    // number of operands or a change in the number of symbolic operands
-    // (promotion of dims to symbols).
-    if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
-        set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
-      ifOp.setConditional(set, operands);
-      rewriter.updatedRootInPlace(ifOp);
-      return matchSuccess();
-    }
-
-    return matchFailure();
+  // Any canonicalization change always leads to either a reduction in the
+  // number of operands or a change in the number of symbolic operands
+  // (promotion of dims to symbols).
+  if (operands.size() < getIntegerSet().getNumInputs() ||
+      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
+    setConditional(set, operands);
+    return success();
   }
-};
-} // end anonymous namespace
 
-void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                             MLIRContext *context) {
-  results.insert<AffineIfOpCanonicalizer>(context);
+  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1866,11 +1827,16 @@ LogicalResult AffineLoadOp::verify() {
 
 void AffineLoadOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  /// load(memrefcast) -> load
-  results.insert<MemRefCastFolder>(getOperationName(), context);
   results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
 }
 
+OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
+  /// load(memrefcast) -> load
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // AffineStoreOp
 //===----------------------------------------------------------------------===//
@@ -1959,11 +1925,15 @@ LogicalResult AffineStoreOp::verify() {
 
 void AffineStoreOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  /// load(memrefcast) -> load
-  results.insert<MemRefCastFolder>(getOperationName(), context);
   results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
 }
 
+LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+  /// store(memrefcast) -> store
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AffineMinOp
 //===----------------------------------------------------------------------===//
index b618ac0..51f1994 100644 (file)
@@ -32,38 +32,6 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
-
-namespace {
-
-/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
-/// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite
-    : public OpRewritePattern<StorageCastOp> {
-public:
-  using OpRewritePattern<StorageCastOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(StorageCastOp op,
-                                     PatternRewriter &rewriter) const override {
-    if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
-      return matchFailure();
-    auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
-    if (srcScastOp.arg()->getType() != op.getType())
-      return matchFailure();
-
-    rewriter.replaceOp(op, srcScastOp.arg());
-    return matchSuccess();
-  }
-};
-
-} // end anonymous namespace
-
-void StorageCastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
-}
-
 QuantizationDialect::QuantizationDialect(MLIRContext *context)
     : Dialect(/*name=*/"quant", context) {
   addTypes<AnyQuantizedType, UniformQuantizedType,
@@ -73,3 +41,15 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
 #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
       >();
 }
+
+OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
+  /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
+  /// value of x if the casts invert each other.
+  auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp());
+  if (!srcScastOp || srcScastOp.arg()->getType() != getType())
+    return OpFoldResult();
+  return srcScastOp.arg();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
index 713546f..3189e42 100644 (file)
@@ -212,32 +212,20 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
 // Common canonicalization pattern support logic
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
 /// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
-  /// The rootOpName is the name of the root operation to match against.
-  MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
-      : RewritePattern(rootOpName, 1, context) {}
-
-  PatternMatchResult match(Operation *op) const override {
-    for (auto *operand : op->getOperands())
-      if (matchPattern(operand, m_Op<MemRefCastOp>()))
-        return matchSuccess();
-
-    return matchFailure();
-  }
-
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
-      if (auto *memref = op->getOperand(i)->getDefiningOp())
-        if (auto cast = dyn_cast<MemRefCastOp>(memref))
-          op->setOperand(i, cast.getOperand());
-    rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+    if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+      operand.set(cast.getOperand());
+      folded = true;
+    }
   }
-};
-} // end anonymous namespace.
+  return success(folded);
+}
 
 //===----------------------------------------------------------------------===//
 // AddFOp
@@ -1232,11 +1220,15 @@ static LogicalResult verify(DeallocOp op) {
 
 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  /// dealloc(memrefcast) -> dealloc
-  results.insert<MemRefCastFolder>(getOperationName(), context);
   results.insert<SimplifyDeadDealloc>(context);
 }
 
+LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+                              SmallVectorImpl<OpFoldResult> &results) {
+  /// dealloc(memrefcast) -> dealloc
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
@@ -1304,7 +1296,6 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   // The size at getIndex() is now a dynamic size of a memref.
-
   auto memref = memrefOrTensor()->getDefiningOp();
   if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
     return *(alloc.getDynamicSizes().begin() +
@@ -1321,13 +1312,11 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
       return *(sizes.begin() + getIndex());
   }
 
-  return {};
-}
-
-void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                        MLIRContext *context) {
   /// dim(memrefcast) -> dim
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
@@ -1507,10 +1496,10 @@ LogicalResult DmaStartOp::verify() {
   return success();
 }
 
-void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                             MLIRContext *context) {
+LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+                               SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  return foldMemRefCast(*this);
 }
 
 // ---------------------------------------------------------------------------
@@ -1565,10 +1554,10 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                            MLIRContext *context) {
+LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+                              SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  return foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1688,10 +1677,11 @@ static LogicalResult verify(LoadOp op) {
   return success();
 }
 
-void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                         MLIRContext *context) {
+OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
   /// load(memrefcast) -> load
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
 }
 
 //===----------------------------------------------------------------------===//
@@ -2092,10 +2082,10 @@ static LogicalResult verify(StoreOp op) {
   return success();
 }
 
-void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context) {
+LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+                            SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  results.insert<MemRefCastFolder>(getOperationName(), context);
+  return foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//