Add a generic pattern matcher for matching constant values produced by an operation...
authorRiver Riddle <riverriddle@google.com>
Tue, 19 Feb 2019 17:33:11 +0000 (09:33 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:31:56 +0000 (16:31 -0700)
PiperOrigin-RevId: 234616691

mlir/include/mlir/IR/Matchers.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index d162a6aff467ca61eb41604c50229c4a0800feee..7de84b58d40d2b246e1a01851068927e7970fd5f 100644 (file)
@@ -60,6 +60,27 @@ struct attr_value_binder {
   }
 };
 
+/// The matcher that matches a constant foldable operation that has no operands
+/// and produces a single result.
+struct constant_op_binder {
+  Attribute *bind_value;
+
+  /// Creates a matcher instance that binds the constant attribute value to
+  /// bind_value if match succeeds.
+  constant_op_binder(Attribute *bind_value) : bind_value(bind_value) {}
+
+  bool match(Instruction *op) {
+    if (op->getNumOperands() > 0 || op->getNumResults() != 1)
+      return false;
+    SmallVector<Attribute, 1> foldedAttr;
+    if (!op->constantFold(/*operands=*/llvm::None, foldedAttr)) {
+      *bind_value = foldedAttr.front();
+      return true;
+    }
+    return false;
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// integer operation and binds the constant integer value.
 struct constant_int_op_binder {
@@ -69,18 +90,18 @@ struct constant_int_op_binder {
   constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
 
   bool match(Instruction *op) {
-    if (auto constOp = op->dyn_cast<ConstantOp>()) {
-      auto type = constOp->getResult()->getType();
-      auto attr = constOp->getAttr("value");
+    Attribute attr;
+    if (!constant_op_binder(&attr).match(op))
+      return false;
+    auto type = op->getResult(0)->getType();
 
-      if (type.isa<IntegerType>()) {
-        return attr_value_binder<IntegerAttr>(bind_value).match(attr);
-      }
-      if (type.isa<VectorOrTensorType>()) {
-        if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
-          return attr_value_binder<IntegerAttr>(bind_value)
-              .match(splatAttr.getValue());
-        }
+    if (type.isa<IntegerType>()) {
+      return attr_value_binder<IntegerAttr>(bind_value).match(attr);
+    }
+    if (type.isa<VectorOrTensorType>()) {
+      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+        return attr_value_binder<IntegerAttr>(bind_value)
+            .match(splatAttr.getValue());
       }
     }
     return false;
@@ -118,13 +139,19 @@ inline detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
   return detail::op_matcher<ConstantIndexOp>();
 }
 
-/// Matches a ConstantOp holding a scalar/vector/tensor integer (splat) and
+/// Matches a constant holding a scalar/vector/tensor integer (splat) and
 /// writes the integer value to bind_value.
 inline detail::constant_int_op_binder
 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
   return detail::constant_int_op_binder(bind_value);
 }
 
+/// Matches a value from a constant foldable operation and writes the value to
+/// bind_value.
+inline detail::constant_op_binder m_Constant(Attribute *bind_value) {
+  return detail::constant_op_binder(bind_value);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_value_matcher<1> m_One() {
   return detail::constant_int_value_matcher<1>();
index 1859a640a455bc4e256bbf5e13d4ca5a34236b1c..0b82df271f5c418291b5d69267aeb2f373ef0970 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/StandardOps/StandardOps.h"
@@ -825,11 +826,9 @@ struct AffineForLoopBoundFolder : public RewritePattern {
       SmallVector<Attribute, 8> operandConstants;
       auto boundOperands = lower ? forOp->getLowerBoundOperands()
                                  : forOp->getUpperBoundOperands();
-      for (const auto *operand : boundOperands) {
+      for (auto *operand : boundOperands) {
         Attribute operandCst;
-        if (auto *operandOp = operand->getDefiningInst())
-          if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
-            operandCst = operandConstantOp->getValue();
+        matchPattern(operand, m_Constant(&operandCst));
         operandConstants.push_back(operandCst);
       }
 
index 671d72a461c496d41202755823d3b669cac6556f..71de1ef1830919b3cdca8dbf58af4906a61fa9f5 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/DenseMap.h"
 using namespace mlir;
@@ -200,15 +201,9 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
 
     // Check to see if any operands to the instruction is constant and whether
     // the operation knows how to constant fold itself.
-    operandConstants.clear();
-    for (auto *operand : op->getOperands()) {
-      Attribute operandCst;
-      if (auto *operandOp = operand->getDefiningInst()) {
-        if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
-          operandCst = operandConstantOp->getValue();
-      }
-      operandConstants.push_back(operandCst);
-    }
+    operandConstants.assign(op->getNumOperands(), Attribute());
+    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+      matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
 
     // If this is a commutative binary operation with a constant on the left
     // side move it to the right side.