Unify the 'constantFold' and 'fold' hooks on an operation into just 'fold'. This...
authorRiver Riddle <riverriddle@google.com>
Thu, 16 May 2019 19:51:45 +0000 (12:51 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:44:24 +0000 (13:44 -0700)
--

PiperOrigin-RevId: 248582024

21 files changed:
mlir/g3doc/OpDefinitions.md
mlir/g3doc/QuickstartRewrites.md
mlir/include/mlir/AffineOps/AffineOps.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/include/mlir/Transforms/FoldUtils.h [moved from mlir/include/mlir/Transforms/ConstantFoldUtils.h with 65% similarity]
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/TestConstantFold.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp [moved from mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp with 75% similarity]
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 2b2e5af..5e79bd3 100644 (file)
@@ -371,11 +371,6 @@ This boolean field indicate whether canonicalization patterns have been defined
 for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
 be defined.
 
-### `hasConstantFolder`
-
-This boolean field indicate whether constant folding rules have been defined
-for this operation. If it is `1`, then `::constantFold()` should be defined.
-
 ### `hasFolder`
 
 This boolean field indicate whether general folding rules have been defined
index db440b2..19045f6 100644 (file)
@@ -74,11 +74,10 @@ in either way.
 Operations can also have custom parser, printer, builder, verifier, constant
 folder, or canonicalizer. These require specifying additional C++ methods to
 invoke for additional functionality. For example, if an operation is marked to
-have a constant folder, the constant folder also needs to be added, e.g.,:
+have a folder, the constant folder also needs to be added, e.g.,:
 
 ```c++
-Attribute SpecificOp::constantFold(ArrayRef<Attribute> operands,
-                                   MLIRContext *context) {
+OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
   if (unable_to_fold)
     return {};
   ....
index a164807..d4dd214 100644 (file)
@@ -83,7 +83,7 @@ public:
   static ParseResult parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p);
   LogicalResult verify();
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+  OpFoldResult fold(ArrayRef<Attribute> operands);
 
   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context);
index a8e45df..6c2286e 100644 (file)
@@ -761,7 +761,7 @@ private:
 
 namespace llvm {
 
-// Attribute hash just like pointers
+// Attribute hash just like pointers.
 template <> struct DenseMapInfo<mlir::Attribute> {
   static mlir::Attribute getEmptyKey() {
     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
@@ -779,6 +779,18 @@ template <> struct DenseMapInfo<mlir::Attribute> {
   }
 };
 
+/// Allow LLVM to steal the low bits of Attributes.
+template <> struct PointerLikeTypeTraits<mlir::Attribute> {
+public:
+  static inline void *getAsVoidPointer(mlir::Attribute attr) {
+    return const_cast<void *>(attr.getAsOpaquePointer());
+  }
+  static inline mlir::Attribute getFromVoidPointer(void *ptr) {
+    return mlir::Attribute::getFromOpaquePointer(ptr);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
 } // namespace llvm
 
 #endif
index 9a7cb2d..3e70d7c 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef MLIR_MATCHERS_H
 #define MLIR_MATCHERS_H
 
-#include "mlir/IR/Operation.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include <type_traits>
 
@@ -73,9 +73,9 @@ struct constant_op_binder {
     if (!op->hasNoSideEffect())
       return false;
 
-    SmallVector<Attribute, 1> foldedAttr;
-    if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
-      *bind_value = foldedAttr.front();
+    SmallVector<OpFoldResult, 1> foldedOp;
+    if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
+      *bind_value = foldedOp.front().dyn_cast<Attribute>();
       return true;
     }
     return false;
index ecd61db..440bf91 100644 (file)
@@ -925,9 +925,6 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // and C++ implementations.
   bit hasCanonicalizer = 0;
 
-  // Whether this op has a constant folder.
-  bit hasConstantFolder = 0;
-
   // Whether this op has a folder.
   bit hasFolder = 0;
 
index 788c01a..472441e 100644 (file)
@@ -170,47 +170,28 @@ inline bool operator!=(OpState lhs, OpState rhs) {
   return lhs.getOperation() != rhs.getOperation();
 }
 
-/// This template defines the constantFoldHook and foldHook as used by
-/// AbstractOperation.
+/// This class represents a single result from folding an operation.
+class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
+  using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
+};
+
+/// This template defines the foldHook as used by AbstractOperation.
 ///
-/// The default implementation uses a general constantFold/fold method that can
-/// be defined on custom ops which can return multiple results.
+/// The default implementation uses a general fold method that can be defined on
+/// custom ops which can return multiple results.
 template <typename ConcreteType, bool isSingleResult, typename = void>
 class FoldingHook {
 public:
   /// This is an implementation detail of the constant folder hook for
   /// AbstractOperation.
-  static LogicalResult constantFoldHook(Operation *op,
-                                        ArrayRef<Attribute> operands,
-                                        SmallVectorImpl<Attribute> &results) {
-    return cast<ConcreteType>(op).constantFold(operands, results,
-                                               op->getContext());
-  }
-
-  /// Op implementations can implement this hook.  It should attempt to constant
-  /// fold this operation with the specified constant operand values - the
-  /// elements in "operands" will correspond directly to the operands of the
-  /// operation, but may be null if non-constant.  If constant folding is
-  /// successful, this fills in the `results` vector.  If not, `results` is
-  /// unspecified.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  LogicalResult constantFold(ArrayRef<Attribute> operands,
-                             SmallVectorImpl<Attribute> &results,
-                             MLIRContext *context) {
-    return failure();
-  }
-
-  /// This is an implementation detail of the folder hook for AbstractOperation.
-  static LogicalResult foldHook(Operation *op,
-                                SmallVectorImpl<Value *> &results) {
-    return cast<ConcreteType>(op).fold(results);
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    return cast<ConcreteType>(op).fold(operands, results);
   }
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::foldOrCreate API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
@@ -225,23 +206,25 @@ public:
   ///     instead.
   ///
   /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
-  /// not allow for canonicalizations that need to introduce new operations, not
-  /// even constants (e.g. "x-x -> 0" cannot be expressed).
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
   ///
   /// If not overridden, this fallback implementation always fails to fold.
   ///
-  LogicalResult fold(SmallVectorImpl<Value *> &results) { return failure(); }
+  LogicalResult fold(ArrayRef<Attribute> operands,
+                     SmallVectorImpl<OpFoldResult> &results) {
+    return failure();
+  }
 };
 
-/// This template specialization defines the constantFoldHook and foldHook as
-/// used by AbstractOperation for single-result operations.  This gives the hook
-/// a nicer signature that is easier to implement.
+/// This template specialization defines the foldHook as used by
+/// AbstractOperation for single-result operations.  This gives the hook a nicer
+/// signature that is easier to implement.
 template <typename ConcreteType, bool isSingleResult>
 class FoldingHook<ConcreteType, isSingleResult,
                   typename std::enable_if<isSingleResult>::type> {
 public:
-  /// If the operation returns a single value, then the Op can  be implicitly
+  /// If the operation returns a single value, then the Op can be implicitly
   /// converted to an Value*.  This yields the value of the only result.
   operator Value *() {
     return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
@@ -249,11 +232,9 @@ public:
 
   /// This is an implementation detail of the constant folder hook for
   /// AbstractOperation.
-  static LogicalResult constantFoldHook(Operation *op,
-                                        ArrayRef<Attribute> operands,
-                                        SmallVectorImpl<Attribute> &results) {
-    auto result =
-        cast<ConcreteType>(op).constantFold(operands, op->getContext());
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    auto result = cast<ConcreteType>(op).fold(operands);
     if (!result)
       return failure();
 
@@ -261,33 +242,9 @@ public:
     return success();
   }
 
-  /// Op implementations can implement this hook.  It should attempt to constant
-  /// fold this operation with the specified constant operand values - the
-  /// elements in "operands" will correspond directly to the operands of the
-  /// operation, but may be null if non-constant.  If constant folding is
-  /// successful, this returns a non-null attribute, otherwise it returns null
-  /// on failure.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context) {
-    return nullptr;
-  }
-
-  /// This is an implementation detail of the folder hook for AbstractOperation.
-  static LogicalResult foldHook(Operation *op,
-                                SmallVectorImpl<Value *> &results) {
-    auto *result = cast<ConcreteType>(op).fold();
-    if (!result)
-      return failure();
-    if (result != op->getResult(0))
-      results.push_back(result);
-    return success();
-  }
-
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::foldOrCreate API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
@@ -301,13 +258,12 @@ public:
   ///     remove the operation and use that result instead.
   ///
   /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
-  /// not allow for canonicalizations that need to introduce new operations, not
-  /// even constants (e.g. "x-x -> 0" cannot be expressed).
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
   ///
   /// If not overridden, this fallback implementation always fails to fold.
   ///
-  Value *fold() { return nullptr; }
+  OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
 };
 
 //===----------------------------------------------------------------------===//
index e71e8ed..a4af5de 100644 (file)
@@ -374,16 +374,12 @@ public:
     return getTerminatorStatus() == TerminatorStatus::NonTerminator;
   }
 
-  /// Attempt to constant fold this operation with the specified constant
-  /// operand values - the elements in "operands" will correspond directly to
-  /// the operands of the operation, but may be null if non-constant.  If
-  /// constant folding is successful, this fills in the `results` vector.  If
-  /// not, `results` is unspecified.
-  LogicalResult constantFold(ArrayRef<Attribute> operands,
-                             SmallVectorImpl<Attribute> &results);
-
-  /// Attempt to fold this operation using the Op's registered foldHook.
-  LogicalResult fold(SmallVectorImpl<Value *> &results);
+  /// Attempt to fold this operation with the specified constant operand values
+  /// - the elements in "operands" will correspond directly to the operands of
+  /// the operation, but may be null if non-constant. If folding is successful,
+  /// this fills in the `results` vector. If not, `results` is unspecified.
+  LogicalResult fold(ArrayRef<Attribute> operands,
+                     SmallVectorImpl<OpFoldResult> &results);
 
   //===--------------------------------------------------------------------===//
   // Operation Walkers
index 169012a..312cf14 100644 (file)
@@ -41,6 +41,7 @@ struct OperationState;
 class OpAsmParser;
 class OpAsmParserResult;
 class OpAsmPrinter;
+class OpFoldResult;
 class ParseResult;
 class Pattern;
 class Region;
@@ -95,14 +96,9 @@ public:
   /// success if everything is ok.
   LogicalResult (&verifyInvariants)(Operation *op);
 
-  /// This hook implements a constant folder for this operation.  It fills in
-  /// `results` on success.
-  LogicalResult (&constantFoldHook)(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<Attribute> &results);
-
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::foldOrCreate API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
@@ -117,10 +113,10 @@ public:
   ///     instead.
   ///
   /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
-  /// not allow for canonicalizations that need to introduce new operations, not
-  /// even constants (e.g. "x-x -> 0" cannot be expressed).
-  LogicalResult (&foldHook)(Operation *op, SmallVectorImpl<Value *> &results);
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
+  LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results);
 
   /// This hook returns any canonicalization pattern rewrites that the operation
   /// supports, for use by the canonicalization pass.
@@ -142,8 +138,8 @@ public:
   template <typename T> static AbstractOperation get(Dialect &dialect) {
     return AbstractOperation(
         T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
-        T::parseAssembly, T::printAssembly, T::verifyInvariants,
-        T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
+        T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
+        T::getCanonicalizationPatterns);
   }
 
 private:
@@ -153,17 +149,13 @@ private:
       ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
       void (&printAssembly)(Operation *op, OpAsmPrinter *p),
       LogicalResult (&verifyInvariants)(Operation *op),
-      LogicalResult (&constantFoldHook)(Operation *op,
-                                        ArrayRef<Attribute> operands,
-                                        SmallVectorImpl<Attribute> &results),
-      LogicalResult (&foldHook)(Operation *op,
-                                SmallVectorImpl<Value *> &results),
+      LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results),
       void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
                                           MLIRContext *context))
       : name(name), dialect(dialect), classof(classof),
         parseAssembly(parseAssembly), printAssembly(printAssembly),
-        verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
-        foldHook(foldHook),
+        verifyInvariants(verifyInvariants), foldHook(foldHook),
         getCanonicalizationPatterns(getCanonicalizationPatterns),
         opProperties(opProperties) {}
 
index 866012b..6db6fe0 100644 (file)
@@ -101,7 +101,7 @@ public:
   static ParseResult parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p);
   LogicalResult verify();
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+  OpFoldResult fold(ArrayRef<Attribute> operands);
 };
 
 /// The predicate indicates the type of the comparison to perform:
@@ -176,7 +176,7 @@ public:
   static ParseResult parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p);
   LogicalResult verify();
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+  OpFoldResult fold(ArrayRef<Attribute> operands);
 };
 
 /// The "cond_br" operation represents a conditional branch operation in a
@@ -600,7 +600,7 @@ public:
   Value *getTrueValue() { return getOperand(1); }
   Value *getFalseValue() { return getOperand(2); }
 
-  Value *fold();
+  OpFoldResult fold(ArrayRef<Attribute> operands);
 };
 
 /// The "store" op writes an element to a memref specified by an index list.
index ffb2150..e3b521a 100644 (file)
@@ -112,13 +112,12 @@ class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 
 def AddFOp : FloatArithmeticOp<"addf"> {
   let summary = "floating point addition operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let hasFolder = 1;
-  let hasConstantFolder = 1;
 }
 
 def AllocOp : Std_Op<"alloc"> {
@@ -163,7 +162,6 @@ def AllocOp : Std_Op<"alloc"> {
 
 def AndOp : IntArithmeticOp<"and", [Commutative]> {
   let summary = "integer binary and";
-  let hasConstantFolder = 1;
   let hasFolder = 1;
 }
 
@@ -288,7 +286,7 @@ def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
     Attribute getValue() { return getAttr("value"); }
   }];
 
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def DeallocOp : Std_Op<"dealloc"> {
@@ -338,7 +336,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
     }
   }];
 
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def DivFOp : FloatArithmeticOp<"divf"> {
@@ -347,12 +345,12 @@ def DivFOp : FloatArithmeticOp<"divf"> {
 
 def DivISOp : IntArithmeticOp<"divis"> {
   let summary = "signed integer division operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def DivIUOp : IntArithmeticOp<"diviu"> {
   let summary = "unsigned integer division operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
@@ -389,7 +387,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
     }
   }];
 
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def MemRefCastOp : CastOp<"memref_cast"> {
@@ -426,18 +424,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
 
 def MulFOp : FloatArithmeticOp<"mulf"> {
   let summary = "foating point multiplication operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
   let summary = "integer multiplication operation";
-  let hasConstantFolder = 1;
   let hasFolder = 1;
 }
 
 def OrOp : IntArithmeticOp<"or", [Commutative]> {
   let summary = "integer binary or";
-  let hasConstantFolder = 1;
   let hasFolder = 1;
 }
 
@@ -447,12 +443,12 @@ def RemFOp : FloatArithmeticOp<"remf"> {
 
 def RemISOp : IntArithmeticOp<"remis"> {
   let summary = "signed integer division remainder operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def RemIUOp : IntArithmeticOp<"remiu"> {
   let summary = "unsigned integer division remainder operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def ReturnOp : Std_Op<"return", [Terminator]> {
@@ -481,13 +477,12 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
 
 def SubFOp : FloatArithmeticOp<"subf"> {
   let summary = "floating point subtraction operation";
-  let hasConstantFolder = 1;
+  let hasFolder = 1;
 }
 
 def SubIOp : IntArithmeticOp<"subi"> {
   let summary = "integer subtraction operation";
-  let hasConstantFolder = 1;
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def TensorCastOp : CastOp<"tensor_cast"> {
@@ -518,8 +513,6 @@ def TensorCastOp : CastOp<"tensor_cast"> {
 
 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
   let summary = "integer binary xor";
-  let hasConstantFolder = 1;
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
similarity index 65%
rename from mlir/include/mlir/Transforms/ConstantFoldUtils.h
rename to mlir/include/mlir/Transforms/FoldUtils.h
index d2309a3..6264bd1 100644 (file)
@@ -1,4 +1,4 @@
-//===- ConstantFoldUtils.h - Constant Fold Utilities ------------*- C++ -*-===//
+//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // limitations under the License.
 // =============================================================================
 //
-// This header file declares various constant fold utilities. These utilities
-// are intended to be used by passes to unify and simply their logic.
+// This header file declares various operation folding utilities. These
+// utilities are intended to be used by passes to unify and simply their logic.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_TRANSFORMS_CONSTANT_UTILS_H
-#define MLIR_TRANSFORMS_CONSTANT_UTILS_H
+#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
+#define MLIR_TRANSFORMS_FOLDUTILS_H
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Types.h"
@@ -32,13 +32,13 @@ namespace mlir {
 class Function;
 class Operation;
 
-/// A helper class for constant folding operations, and unifying duplicated
-/// constants along the way.
+/// A helper class for folding operations, and unifying duplicated constants
+/// generated along the way.
 ///
-/// To make sure constants' proper dominance of all their uses, constants are
+/// To make sure constants properly dominate all their uses, constants are
 /// moved to the beginning of the entry block of the function when tracked by
 /// this class.
-class ConstantFoldHelper {
+class FoldHelper {
 public:
   /// Constructs an instance for managing constants in the given function `f`.
   /// Constants tracked by this instance will be moved to the entry block of
@@ -47,32 +47,30 @@ public:
   /// This instance does not proactively walk the operations inside `f`;
   /// instead, users must invoke the following methods to manually handle each
   /// operation of interest.
-  ConstantFoldHelper(Function *f);
+  FoldHelper(Function *f);
 
-  /// Tries to perform constant folding on the given `op`, including unifying
-  /// deplicated constants. If successful, calls `preReplaceAction` (if
+  /// Tries to perform folding on the given `op`, including unifying
+  /// deduplicated constants. If successful, calls `preReplaceAction` (if
   /// provided) by passing in `op`, then replaces `op`'s uses with folded
-  /// constants, and returns true.
-  ///
-  /// Note: `op` will *not* be erased to avoid invalidating potential walkers in
-  /// the caller.
-  bool
-  tryToConstantFold(Operation *op,
-                    std::function<void(Operation *)> preReplaceAction = {});
+  /// results, and returns success. If the op was completely folded it is
+  /// erased.
+  LogicalResult
+  tryToFold(Operation *op,
+            std::function<void(Operation *)> preReplaceAction = {});
 
   /// Notifies that the given constant `op` should be remove from this
-  /// ConstantFoldHelper's internal bookkeeping.
+  /// FoldHelper's internal bookkeeping.
   ///
   /// Note: this method must be called if a constant op is to be deleted
-  /// externally to this ConstantFoldHelper. `op` must be a constant op.
+  /// externally to this FoldHelper. `op` must be a constant op.
   void notifyRemoval(Operation *op);
 
 private:
-  /// Tries to deduplicate the given constant and returns true if that can be
+  /// Tries to deduplicate the given constant and returns success if that can be
   /// done. This moves the given constant to the top of the entry block if it
   /// is first seen. If there is already an existing constant that is the same,
   /// this does *not* erases the given constant.
-  bool tryToUnify(Operation *op);
+  LogicalResult tryToUnify(Operation *op);
 
   /// Moves the given constant `op` to entry block to guarantee dominance.
   void moveConstantToEntryBlock(Operation *op);
@@ -86,4 +84,4 @@ private:
 
 } // end namespace mlir
 
-#endif // MLIR_TRANSFORMS_CONSTANT_UTILS_H
+#endif // MLIR_TRANSFORMS_FOLDUTILS_H
index 40069f6..130cb15 100644 (file)
@@ -201,12 +201,11 @@ bool AffineApplyOp::isValidSymbol() {
                       [](Value *op) { return mlir::isValidSymbol(op); });
 }
 
-Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
-                                      MLIRContext *context) {
+OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
   auto map = getAffineMap();
   SmallVector<Attribute, 1> result;
   if (failed(map.constantFold(operands, result)))
-    return Attribute();
+    return {};
   return result[0];
 }
 
index 2a67f5a..b5a7a20 100644 (file)
@@ -509,40 +509,32 @@ auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
                            succOperandIndex + getNumSuccessorOperands(index))};
 }
 
-/// Attempt to constant fold this operation with the specified constant
-/// operand values.  If successful, this fills in the results vector.  If not,
-/// results is unspecified.
-LogicalResult Operation::constantFold(ArrayRef<Attribute> operands,
-                                      SmallVectorImpl<Attribute> &results) {
-  if (auto *abstractOp = getAbstractOperation()) {
-    // If we have a registered operation definition matching this one, use it to
-    // try to constant fold the operation.
-    if (succeeded(abstractOp->constantFoldHook(this, operands, results)))
-      return success();
-
-    // Otherwise, fall back on the dialect hook to handle it.
-    return abstractOp->dialect.constantFoldHook(this, operands, results);
-  }
-
-  // If this operation hasn't been registered or doesn't have abstract
-  // operation, fall back to a dialect which matches the prefix.
-  auto opName = getName().getStringRef();
-  auto dialectPrefix = opName.split('.').first;
-  if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix))
-    return dialect->constantFoldHook(this, operands, results);
-
-  return failure();
-}
-
 /// Attempt to fold this operation using the Op's registered foldHook.
-LogicalResult Operation::fold(SmallVectorImpl<Value *> &results) {
-  if (auto *abstractOp = getAbstractOperation()) {
-    // If we have a registered operation definition matching this one, use it to
-    // try to constant fold the operation.
-    if (succeeded(abstractOp->foldHook(this, results)))
-      return success();
+LogicalResult Operation::fold(ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results) {
+  // If we have a registered operation definition matching this one, use it to
+  // try to constant fold the operation.
+  auto *abstractOp = getAbstractOperation();
+  if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
+    return success();
+
+  // Otherwise, fall back on the dialect hook to handle it.
+  Dialect *dialect;
+  if (abstractOp) {
+    dialect = &abstractOp->dialect;
+  } else {
+    // If this operation hasn't been registered, lookup the parent dialect.
+    auto opName = getName().getStringRef();
+    auto dialectPrefix = opName.split('.').first;
+    if (!(dialect = getContext()->getRegisteredDialect(dialectPrefix)))
+      return failure();
   }
-  return failure();
+
+  SmallVector<Attribute, 8> constants;
+  if (failed(dialect->constantFoldHook(this, operands, constants)))
+    return failure();
+  results.assign(constants.begin(), constants.end());
+  return success();
 }
 
 /// Emit an error with the op name prefixed, like "'dim' op " which is
index 490f6e4..214d5a9 100644 (file)
@@ -196,8 +196,7 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
   return constFoldBinaryOp<FloatAttr>(
       operands, [](APFloat a, APFloat b) { return a + b; });
 }
@@ -206,18 +205,13 @@ Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a + b; });
-}
-
-Value *AddIOp::fold() {
+OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
   /// addi(x, 0) -> x
-  if (matchPattern(getOperand(1), m_Zero()))
-    return getOperand(0);
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
 
-  return nullptr;
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a + b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -770,8 +764,7 @@ static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
 }
 
 // Constant folding hook for comparisons.
-Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -780,7 +773,7 @@ Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
 }
 
 //===----------------------------------------------------------------------===//
@@ -967,8 +960,7 @@ static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
 }
 
 // Constant folding hook for comparisons.
-Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpf takes two arguments");
 
   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
@@ -980,7 +972,7 @@ Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1179,8 +1171,7 @@ static LogicalResult verify(ConstantOp &op) {
       "requires a result type that aligns with the 'value' attribute");
 }
 
-Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
-                                   MLIRContext *context) {
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.empty() && "constant has no operands");
   return getValue();
 }
@@ -1337,8 +1328,7 @@ static LogicalResult verify(DimOp op) {
   return success();
 }
 
-Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
-                              MLIRContext *context) {
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // Constant fold dim when the size along the index referred to is a constant.
   auto opType = getOperand()->getType();
   int64_t indexSize = -1;
@@ -1348,19 +1338,17 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
     indexSize = memrefType.getShape()[getIndex()];
 
   if (indexSize >= 0)
-    return IntegerAttr::get(IndexType::get(context), indexSize);
+    return IntegerAttr::get(IndexType::get(getContext()), indexSize);
 
-  return nullptr;
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
 // DivISOp
 //===----------------------------------------------------------------------===//
 
-Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
-                                MLIRContext *context) {
+OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "binary operation takes two operands");
-  (void)context;
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1368,9 +1356,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
     return {};
 
   // Don't fold if it requires division by zero.
-  if (rhs.getValue().isNullValue()) {
+  if (rhs.getValue().isNullValue())
     return {};
-  }
 
   // Don't fold if it would overflow.
   bool overflow;
@@ -1382,10 +1369,8 @@ Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
 // DivIUOp
 //===----------------------------------------------------------------------===//
 
-Attribute DivIUOp::constantFold(ArrayRef<Attribute> operands,
-                                MLIRContext *context) {
+OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "binary operation takes two operands");
-  (void)context;
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1675,14 +1660,13 @@ static LogicalResult verify(ExtractElementOp op) {
   return success();
 }
 
-Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
-                                         MLIRContext *context) {
+OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
   assert(!operands.empty() && "extract_element takes atleast one operand");
 
   // The aggregate operand must be a known constant.
   Attribute aggregate = operands.front();
   if (!aggregate)
-    return Attribute();
+    return {};
 
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
@@ -1693,14 +1677,14 @@ Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : llvm::drop_begin(operands, 1)) {
     if (!indice || !indice.isa<IntegerAttr>())
-      return Attribute();
+      return {};
     indices.push_back(indice.cast<IntegerAttr>().getInt());
   }
 
   // If this is an elements attribute, query the value at the given indices.
   if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>())
     return elementsAttr.getValue(indices);
-  return Attribute();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
@@ -1801,14 +1785,15 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
   return true;
 }
 
-Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
+OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
 
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
   return constFoldBinaryOp<FloatAttr>(
       operands, [](APFloat a, APFloat b) { return a * b; });
 }
@@ -1817,29 +1802,24 @@ Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
-  // TODO: Handle the overflow case.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a * b; });
-}
-
-Value *MulIOp::fold() {
+OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
   /// muli(x, 0) -> 0
-  if (matchPattern(getOperand(1), m_Zero()))
-    return getOperand(1);
+  if (matchPattern(rhs(), m_Zero()))
+    return rhs();
   /// muli(x, 1) -> x
-  if (matchPattern(getOperand(1), m_One()))
+  if (matchPattern(rhs(), m_One()))
     return getOperand(0);
-  return nullptr;
+
+  // TODO: Handle the overflow case.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a * b; });
 }
 
 //===----------------------------------------------------------------------===//
 // RemISOp
 //===----------------------------------------------------------------------===//
 
-Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
-                                MLIRContext *context) {
+OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "remis takes two operands");
 
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1852,9 +1832,8 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
                             APInt(rhs.getValue().getBitWidth(), 0));
 
   // Don't fold if it requires division by zero.
-  if (rhs.getValue().isNullValue()) {
+  if (rhs.getValue().isNullValue())
     return {};
-  }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
@@ -1867,8 +1846,7 @@ Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
 // RemIUOp
 //===----------------------------------------------------------------------===//
 
-Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
-                                MLIRContext *context) {
+OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "remiu takes two operands");
 
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
@@ -1881,9 +1859,8 @@ Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
                             APInt(rhs.getValue().getBitWidth(), 0));
 
   // Don't fold if it requires division by zero.
-  if (rhs.getValue().isNullValue()) {
+  if (rhs.getValue().isNullValue())
     return {};
-  }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
@@ -1990,7 +1967,7 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-Value *SelectOp::fold() {
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   auto *condition = getCondition();
 
   // select true, %0, %1 => %0
@@ -2081,8 +2058,7 @@ void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
   return constFoldBinaryOp<FloatAttr>(
       operands, [](APFloat a, APFloat b) { return a - b; });
 }
@@ -2091,48 +2067,20 @@ Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
-                               MLIRContext *context) {
+OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
+  // subi(x,x) -> 0
+  if (getOperand(0) == getOperand(1))
+    return Builder(getContext()).getZeroAttr(getType());
+
   return constFoldBinaryOp<IntegerAttr>(operands,
                                         [](APInt a, APInt b) { return a - b; });
 }
 
-namespace {
-/// subi(x,x) -> 0
-///
-struct SimplifyXMinusX : public RewritePattern {
-  SimplifyXMinusX(MLIRContext *context)
-      : RewritePattern(SubIOp::getOperationName(), 1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    auto subi = cast<SubIOp>(op);
-    if (subi.getOperand(0) != subi.getOperand(1))
-      return matchFailure();
-
-    rewriter.replaceOpWithNewOp<ConstantOp>(
-        op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
-    return matchSuccess();
-  }
-};
-} // end anonymous namespace.
-
-void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                         MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyXMinusX>(context));
-}
-
 //===----------------------------------------------------------------------===//
 // AndOp
 //===----------------------------------------------------------------------===//
 
-Attribute AndOp::constantFold(ArrayRef<Attribute> operands,
-                              MLIRContext *context) {
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a & b; });
-}
-
-Value *AndOp::fold() {
+OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
   /// and(x, 0) -> 0
   if (matchPattern(rhs(), m_Zero()))
     return rhs();
@@ -2140,20 +2088,15 @@ Value *AndOp::fold() {
   if (lhs() == rhs())
     return rhs();
 
-  return nullptr;
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a & b; });
 }
 
 //===----------------------------------------------------------------------===//
 // OrOp
 //===----------------------------------------------------------------------===//
 
-Attribute OrOp::constantFold(ArrayRef<Attribute> operands,
-                             MLIRContext *context) {
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a | b; });
-}
-
-Value *OrOp::fold() {
+OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
   /// or(x, 0) -> x
   if (matchPattern(rhs(), m_Zero()))
     return lhs();
@@ -2161,51 +2104,26 @@ Value *OrOp::fold() {
   if (lhs() == rhs())
     return rhs();
 
-  return nullptr;
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a | b; });
 }
 
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
 
-Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
-                              MLIRContext *context) {
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a ^ b; });
-}
-
-Value *XOrOp::fold() {
+OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
   /// xor(x, 0) -> x
   if (matchPattern(rhs(), m_Zero()))
     return lhs();
+  /// xor(x,x) -> 0
+  if (lhs() == rhs())
+    return Builder(getContext()).getZeroAttr(getType());
 
-  return nullptr;
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a ^ b; });
 }
 
-namespace {
-/// xor(x,x) -> 0
-///
-struct SimplifyXXOrX : public RewritePattern {
-  SimplifyXXOrX(MLIRContext *context)
-      : RewritePattern(XOrOp::getOperationName(), 1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
-    auto xorOp = cast<XOrOp>(op);
-    if (xorOp.lhs() != xorOp.rhs())
-      return matchFailure();
-
-    rewriter.replaceOpWithNewOp<ConstantOp>(
-        op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
-    return matchSuccess();
-  }
-};
-} // end anonymous namespace.
-
-void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                        MLIRContext *context) {
-  results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
-}
 //===----------------------------------------------------------------------===//
 // TensorCastOp
 //===----------------------------------------------------------------------===//
@@ -2239,7 +2157,9 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
   return true;
 }
 
-Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
+OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
 
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
index c6b58d8..7d5b81e 100644 (file)
@@ -17,7 +17,7 @@ add_llvm_library(MLIRTransforms
   SimplifyAffineStructures.cpp
   StripDebugInfo.cpp
   TestConstantFold.cpp
-  Utils/ConstantFoldUtils.cpp
+  Utils/FoldUtils.cpp
   Utils/GreedyPatternRewriteDriver.cpp
   Utils/LoopUtils.cpp
   Utils/Utils.cpp
index ec1e971..1169607 100644 (file)
@@ -20,7 +20,7 @@
 #include "mlir/IR/Function.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
 
@@ -31,26 +31,22 @@ namespace {
 struct TestConstantFold : public FunctionPass<TestConstantFold> {
   // All constants in the function post folding.
   SmallVector<Operation *, 8> existingConstants;
-  // Operations that were folded and that need to be erased.
-  std::vector<Operation *> opsToErase;
 
-  void foldOperation(Operation *op, ConstantFoldHelper &helper);
+  void foldOperation(Operation *op, FoldHelper &helper);
   void runOnFunction() override;
 };
 } // end anonymous namespace
 
-void TestConstantFold::foldOperation(Operation *op,
-                                     ConstantFoldHelper &helper) {
+void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
   // Attempt to fold the specified operation, including handling unused or
   // duplicated constants.
-  if (helper.tryToConstantFold(op)) {
-    opsToErase.push_back(op);
-  }
+  if (succeeded(helper.tryToFold(op)))
+    return;
+
   // If this op is a constant that are used and cannot be de-duplicated,
   // remember it for cleanup later.
-  else if (auto constant = dyn_cast<ConstantOp>(op)) {
+  if (auto constant = dyn_cast<ConstantOp>(op))
     existingConstants.push_back(op);
-  }
 }
 
 // For now, we do a simple top-down pass over a function folding constants.  We
@@ -58,10 +54,9 @@ void TestConstantFold::foldOperation(Operation *op,
 // branches, or anything else fancy.
 void TestConstantFold::runOnFunction() {
   existingConstants.clear();
-  opsToErase.clear();
 
   auto &f = getFunction();
-  ConstantFoldHelper helper(&f);
+  FoldHelper helper(&f);
 
   // Collect and fold the operations within the function.
   SmallVector<Operation *, 8> ops;
@@ -74,12 +69,6 @@ void TestConstantFold::runOnFunction() {
   for (Operation *op : llvm::reverse(ops))
     foldOperation(op, helper);
 
-  // At this point, these operations are dead, remove them.
-  for (auto *op : opsToErase) {
-    assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
-    op->erase();
-  }
-
   // By the time we are done, we may have simplified a bunch of code, leaving
   // around dead constants.  Check for them now and remove them.
   for (auto *cst : existingConstants) {
similarity index 75%
rename from mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp
rename to mlir/lib/Transforms/Utils/FoldUtils.cpp
index b907840..578b822 100644 (file)
@@ -1,4 +1,4 @@
-//===- ConstantFoldUtils.cpp ---- Constant Fold Utilities -----------------===//
+//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // limitations under the License.
 // =============================================================================
 //
-// This file defines various constant fold utilities. These utilities are
+// This file defines various operation fold utilities. These utilities are
 // intended to be used by passes to unify and simply their logic.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 
 using namespace mlir;
 
-ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {}
+FoldHelper::FoldHelper(Function *f) : function(f) {}
 
-bool ConstantFoldHelper::tryToConstantFold(
-    Operation *op, std::function<void(Operation *)> preReplaceAction) {
+LogicalResult
+FoldHelper::tryToFold(Operation *op,
+                      std::function<void(Operation *)> preReplaceAction) {
   assert(op->getFunction() == function &&
          "cannot constant fold op from another function");
 
@@ -44,13 +45,15 @@ bool ConstantFoldHelper::tryToConstantFold(
     // If this constant is dead, update bookkeeping and signal the caller.
     if (constant.use_empty()) {
       notifyRemoval(op);
-      return true;
+      op->erase();
+      return success();
     }
     // Otherwise, try to see if we can de-duplicate it.
     return tryToUnify(op);
   }
 
-  SmallVector<Attribute, 8> operandConstants, resultConstants;
+  SmallVector<Attribute, 8> operandConstants;
+  SmallVector<OpFoldResult, 8> results;
 
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
@@ -67,8 +70,8 @@ bool ConstantFoldHelper::tryToConstantFold(
   }
 
   // Attempt to constant fold the operation.
-  if (failed(op->constantFold(operandConstants, resultConstants)))
-    return false;
+  if (failed(op->fold(operandConstants, results)))
+    return failure();
 
   // Constant folding succeeded. We will start replacing this op's uses and
   // eventually erase this op. Invoke the callback provided by the caller to
@@ -76,21 +79,35 @@ bool ConstantFoldHelper::tryToConstantFold(
   if (preReplaceAction)
     preReplaceAction(op);
 
+  // Check to see if the operation was just updated in place.
+  if (results.empty())
+    return success();
+  assert(results.size() == op->getNumResults());
+
   // Create the result constants and replace the results.
   FuncBuilder builder(op);
   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
     auto *res = op->getResult(i);
     if (res->use_empty()) // Ignore dead uses.
       continue;
+    assert(!results[i].isNull() && "expected valid OpFoldResult");
+
+    // Check if the result was an SSA value.
+    if (auto *repl = results[i].dyn_cast<Value *>()) {
+      if (repl != res)
+        res->replaceAllUsesWith(repl);
+      continue;
+    }
 
     // If we already have a canonicalized version of this constant, just reuse
     // it.  Otherwise create a new one.
+    Attribute attrRepl = results[i].get<Attribute>();
     auto &constInst =
-        uniquedConstants[std::make_pair(resultConstants[i], res->getType())];
+        uniquedConstants[std::make_pair(attrRepl, res->getType())];
     if (!constInst) {
       // TODO: Extend to support dialect-specific constant ops.
-      auto newOp = builder.create<ConstantOp>(op->getLoc(), res->getType(),
-                                              resultConstants[i]);
+      auto newOp =
+          builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
       // Register to the constant map and also move up to entry block to
       // guarantee dominance.
       constInst = newOp.getOperation();
@@ -98,17 +115,18 @@ bool ConstantFoldHelper::tryToConstantFold(
     }
     res->replaceAllUsesWith(constInst->getResult(0));
   }
+  op->erase();
 
-  return true;
+  return success();
 }
 
-void ConstantFoldHelper::notifyRemoval(Operation *op) {
+void FoldHelper::notifyRemoval(Operation *op) {
   assert(op->getFunction() == function &&
          "cannot remove constant from another function");
 
   Attribute constValue;
-  matchPattern(op, m_Constant(&constValue));
-  assert(constValue);
+  if (!matchPattern(op, m_Constant(&constValue)))
+    return;
 
   // This constant is dead. keep uniquedConstants up to date.
   auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
@@ -116,7 +134,7 @@ void ConstantFoldHelper::notifyRemoval(Operation *op) {
     uniquedConstants.erase(it);
 }
 
-bool ConstantFoldHelper::tryToUnify(Operation *op) {
+LogicalResult FoldHelper::tryToUnify(Operation *op) {
   Attribute constValue;
   matchPattern(op, m_Constant(&constValue));
   assert(constValue);
@@ -127,13 +145,14 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
   if (constInst) {
     // If this constant is already our uniqued one, then leave it alone.
     if (constInst == op)
-      return false;
+      return failure();
 
     // Otherwise replace this redundant constant with the uniqued one.  We know
     // this is safe because we move constants to the top of the function when
     // they are uniqued, so we know they dominate all uses.
     op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
-    return true;
+    op->erase();
+    return success();
   }
 
   // If we have no entry, then we should unique this constant as the
@@ -141,10 +160,10 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
   // entry block of the function.
   constInst = op;
   moveConstantToEntryBlock(op);
-  return false;
+  return failure();
 }
 
-void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
+void FoldHelper::moveConstantToEntryBlock(Operation *op) {
   // Insert at the very top of the entry block.
   auto &entryBB = function->front();
   op->moveBefore(&entryBB, entryBB.begin());
index fbdee58..58940c1 100644 (file)
@@ -22,7 +22,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -148,7 +148,7 @@ private:
 /// Perform the rewrites.
 bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
   Function *fn = builder.getFunction();
-  ConstantFoldHelper helper(fn);
+  FoldHelper helper(fn);
 
   bool changed = false;
   int i = 0;
@@ -171,67 +171,31 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
       // If the operation has no side effects, and no users, then it is
       // trivially dead - remove it.
       if (op->hasNoSideEffect() && op->use_empty()) {
-        // Be careful to update bookkeeping in ConstantHelper to keep
-        // consistency if this is a constant op.
-        if (isa<ConstantOp>(op))
-          helper.notifyRemoval(op);
+        // Be careful to update bookkeeping in FoldHelper to keep consistency if
+        // this is a constant op.
+        helper.notifyRemoval(op);
         op->erase();
         continue;
       }
 
       // Collects all the operands and result uses of the given `op` into work
       // list.
-      auto collectOperandsAndUses = [this](Operation *op) {
+      originalOperands.assign(op->operand_begin(), op->operand_end());
+      auto collectOperandsAndUses = [&](Operation *op) {
         // Add the operands to the worklist for visitation.
-        addToWorklist(op->getOperands());
+        addToWorklist(originalOperands);
+
         // Add all the users of the result to the worklist so we make sure
         // to revisit them.
         //
         // TODO: Add a result->getUsers() iterator.
-        for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+        for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
           for (auto &operand : op->getResult(i)->getUses())
             addToWorklist(operand.getOwner());
-        }
       };
 
-      // Try to constant fold this op.
-      if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
-        assert(op->hasNoSideEffect() &&
-               "Constant folded op with side effects?");
-        op->erase();
-        changed |= true;
-        continue;
-      }
-
-      // Otherwise see if we can use the generic folder API to simplify the
-      // operation.
-      originalOperands.assign(op->operand_begin(), op->operand_end());
-      resultValues.clear();
-      if (succeeded(op->fold(resultValues))) {
-        // If the result was an in-place simplification (e.g. max(x,x,y) ->
-        // max(x,y)) then add the original operands to the worklist so we can
-        // make sure to revisit them.
-        if (resultValues.empty()) {
-          // Add the operands back to the worklist as there may be more
-          // canonicalization opportunities now.
-          addToWorklist(originalOperands);
-        } else {
-          // Otherwise, the operation is simplified away completely.
-          assert(resultValues.size() == op->getNumResults());
-
-          // Notify that we are replacing this operation.
-          notifyRootReplaced(op);
-
-          // Replace the result values and erase the operation.
-          for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
-            auto *res = op->getResult(i);
-            if (!res->use_empty())
-              res->replaceAllUsesWith(resultValues[i]);
-          }
-
-          notifyOperationRemoved(op);
-          op->erase();
-        }
+      // Try to fold this op.
+      if (succeeded(helper.tryToFold(op, collectOperandsAndUses))) {
         changed |= true;
         continue;
       }
index 528caf6..a7583da 100644 (file)
@@ -28,7 +28,6 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
   let verifier = [{ baz }];
 
   let hasCanonicalizer = 1;
-  let hasConstantFolder = 1;
   let hasFolder = 1;
 
   let extraClassDeclaration = [{
@@ -55,8 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
 // CHECK:   void print(OpAsmPrinter *p);
 // CHECK:   LogicalResult verify();
 // CHECK:   static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);
-// CHECK:   LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
-// CHECK:   bool fold(SmallVectorImpl<Value *> &results);
+// CHECK:   LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results);
 // CHECK:   // Display a graph for debugging purposes.
 // CHECK:   void displayGraph();
 // CHECK: };
index b877322..f7ba069 100644 (file)
@@ -834,28 +834,15 @@ void OpEmitter::genCanonicalizerDecls() {
 void OpEmitter::genFolderDecls() {
   bool hasSingleResult = op.getNumResults() == 1;
 
-  if (def.getValueAsBit("hasConstantFolder")) {
-    if (hasSingleResult) {
-      const char *const params =
-          "ArrayRef<Attribute> operands, MLIRContext *context";
-      opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_None,
-                        /*declOnly=*/true);
-    } else {
-      const char *const params =
-          "ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
-          "MLIRContext *context";
-      opClass.newMethod("LogicalResult", "constantFold", params,
-                        OpMethod::MP_None, /*declOnly=*/true);
-    }
-  }
-
   if (def.getValueAsBit("hasFolder")) {
     if (hasSingleResult) {
-      opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
+      const char *const params = "ArrayRef<Attribute> operands";
+      opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
                         /*declOnly=*/true);
     } else {
-      opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
-                        OpMethod::MP_None,
+      const char *const params = "ArrayRef<Attribute> operands, "
+                                 "SmallVectorImpl<OpFoldResult> &results";
+      opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
                         /*declOnly=*/true);
     }
   }