Add new 'createOrFold' methods to FuncBuilder to immediately try to fold an operation...
authorRiver Riddle <riverriddle@google.com>
Wed, 5 Jun 2019 17:50:10 +0000 (10:50 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:18:55 +0000 (16:18 -0700)
PiperOrigin-RevId: 251674299

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/Builders.cpp
mlir/lib/Linalg/Utils/Utils.cpp

index 09eaf56..52a78ed 100644 (file)
@@ -19,7 +19,7 @@
 #define MLIR_IR_BUILDERS_H
 
 #include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
+#include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
 
@@ -297,7 +297,7 @@ public:
   /// Creates an operation given the fields represented as an OperationState.
   virtual Operation *createOperation(const OperationState &state);
 
-  /// Create operation of specific op type at the current insertion point.
+  /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args... args) {
     OperationState state(getContext(), location, OpTy::getOperationName());
@@ -308,6 +308,40 @@ public:
     return result;
   }
 
+  /// Create an operation of specific op type at the current insertion point,
+  /// and immediately try to fold it. This functions populates 'results' with
+  /// the results after folding the operation.
+  template <typename OpTy, typename... Args>
+  void createOrFold(SmallVectorImpl<Value *> &results, Location location,
+                    Args &&... args) {
+    auto op = create<OpTy>(location, std::forward<Args>(args)...);
+    tryFold(op.getOperation(), results);
+  }
+
+  /// Overload to create or fold a single result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
+                          Value *>::type
+  createOrFold(Location location, Args &&... args) {
+    SmallVector<Value *, 1> results;
+    createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
+    return results.front();
+  }
+
+  /// Overload to create or fold a zero result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
+                          OpTy>::type
+  createOrFold(Location location, Args &&... args) {
+    auto op = create<OpTy>(location, std::forward<Args>(args)...);
+    SmallVector<Value *, 0> unused;
+    tryFold(op.getOperation(), unused);
+
+    // Folding cannot remove a zero-result operation, so for convenience we
+    // continue to return it.
+    return op;
+  }
+
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
   /// ( leaving them alone if no entry is present).  Replaces references to
@@ -339,6 +373,10 @@ public:
   }
 
 private:
+  /// Attempts to fold the given operation and places new results within
+  /// 'results'.
+  void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
+
   Region *region;
   Block *block = nullptr;
   Block::iterator insertPoint;
index ea448cc..ad62c59 100644 (file)
@@ -194,7 +194,7 @@ public:
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::createOrFold API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
@@ -250,7 +250,7 @@ public:
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::createOrFold API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
index de53bf6..a5adad7 100644 (file)
@@ -106,7 +106,7 @@ public:
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::foldOrCreate API and the canonicalization pass.
+  /// the Builder::createOrFold API and the canonicalization pass.
   ///
   /// This is an intentionally limited interface - implementations of this hook
   /// can only perform the following changes to the operation:
index d32e705..ac06609 100644 (file)
@@ -362,3 +362,29 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
   block->getOperations().insert(insertPoint, op);
   return op;
 }
+
+/// Attempts to fold the given operation and places new results within
+/// 'results'.
+void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+  results.reserve(op->getNumResults());
+  SmallVector<OpFoldResult, 4> foldResults;
+
+  // Returns if the given fold result corresponds to a valid existing value.
+  auto isValidValue = [](OpFoldResult result) {
+    return result.dyn_cast<Value *>();
+  };
+
+  // Check if the fold failed, or did not result in only existing values.
+  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
+  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
+      !llvm::all_of(foldResults, isValidValue)) {
+    // Simply return the existing operation results.
+    results.assign(op->result_begin(), op->result_end());
+    return;
+  }
+
+  // Populate the results with the folded results and remove the original op.
+  llvm::transform(foldResults, std::back_inserter(results),
+                  [](OpFoldResult result) { return result.get<Value *>(); });
+  op->erase();
+}
index 81fad1c..eec85ff 100644 (file)
@@ -93,31 +93,15 @@ SmallVector<Value *, 8> mlir::linalg::getViewSizes(LinalgOp &linalgOp) {
   return res;
 }
 
-// Folding eagerly is necessary to abide by affine.for static step requirement.
-// We must propagate constants on the steps as aggressively as possible.
-// Returns nullptr if folding is not trivially feasible.
-static Value *tryFold(AffineMap map, ArrayRef<Value *> operands,
-                      FunctionConstants &state) {
-  assert(map.getNumResults() == 1 && "single result map expected");
-  auto expr = map.getResult(0);
-  if (auto dim = expr.dyn_cast<AffineDimExpr>())
-    return operands[dim.getPosition()];
-  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
-    return operands[map.getNumDims() + sym.getPosition()];
-  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
-    return state.getOrCreateIndex(cst.getValue());
-  return nullptr;
-}
-
 static Value *emitOrFoldComposedAffineApply(OpBuilder *b, Location loc,
                                             AffineMap map,
                                             ArrayRef<Value *> operandsRef,
                                             FunctionConstants &state) {
   SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
-  if (auto *v = tryFold(map, operands, state))
-    return v;
-  return b->create<AffineApplyOp>(loc, map, operands);
+  if (auto cst = map.getResult(0).dyn_cast<AffineConstantExpr>())
+    return state.getOrCreateIndex(cst.getValue());
+  return b->createOrFold<AffineApplyOp>(loc, map, operands);
 }
 
 SmallVector<Value *, 4>