Try to fold operations in DialectConversion when trying to legalize.
authorRiver Riddle <riverriddle@google.com>
Fri, 13 Dec 2019 20:21:42 +0000 (12:21 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 14 Dec 2019 00:47:26 +0000 (16:47 -0800)
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter.

PiperOrigin-RevId: 285448440

13 files changed:
mlir/include/mlir/Dialect/AffineOps/AffineOps.h
mlir/include/mlir/Dialect/StandardOps/Ops.h
mlir/include/mlir/Dialect/VectorOps/VectorOps.h
mlir/include/mlir/IR/Builders.h
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Transforms/lower-affine.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestPatterns.cpp

index 8d36473..835ac24 100644 (file)
@@ -47,6 +47,11 @@ class AffineOpsDialect : public Dialect {
 public:
   AffineOpsDialect(MLIRContext *context);
   static StringRef getDialectNamespace() { return "affine"; }
+
+  /// Materialize a single constant operation from a given attribute value with
+  /// the desired resultant type.
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
 };
 
 /// The "affine.apply" operation applies an affine map to a list of operands,
index d01a1ea..c7c8714 100644 (file)
@@ -42,6 +42,11 @@ class StandardOpsDialect : public Dialect {
 public:
   StandardOpsDialect(MLIRContext *context);
   static StringRef getDialectNamespace() { return "std"; }
+
+  /// Materialize a single constant operation from a given attribute value with
+  /// the desired resultant type.
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
 };
 
 /// The predicate indicates the type of the comparison to perform:
index 5b4351b..06672c7 100644 (file)
@@ -37,6 +37,11 @@ class VectorOpsDialect : public Dialect {
 public:
   VectorOpsDialect(MLIRContext *context);
   static StringRef getDialectNamespace() { return "vector"; }
+
+  /// Materialize a single constant operation from a given attribute value with
+  /// the desired resultant type.
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
 };
 
 /// Collect a set of vector-to-vector canonicalization patterns.
index 9c787c1..766902f 100644 (file)
@@ -315,8 +315,17 @@ public:
   template <typename OpTy, typename... Args>
   void createOrFold(SmallVectorImpl<Value *> &results, Location location,
                     Args &&... args) {
-    auto op = create<OpTy>(location, std::forward<Args>(args)...);
-    tryFold(op.getOperation(), results);
+    // Create the operation without using 'createOperation' as we don't want to
+    // insert it yet.
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(this, state, std::forward<Args>(args)...);
+    Operation *op = Operation::create(state);
+
+    // Fold the operation. If successful destroy it, otherwise insert it.
+    if (succeeded(tryFold(op, results)))
+      op->destroy();
+    else
+      insert(op);
   }
 
   /// Overload to create or fold a single result operation.
@@ -343,6 +352,11 @@ public:
     return op;
   }
 
+  /// Attempts to fold the given operation and places new results within
+  /// 'results'. Returns success if the operation was folded, failure otherwise.
+  /// Note: This function does not erase the operation on a successful fold.
+  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value *> &results);
+
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
   /// ( leaving them alone if no entry is present).  Replaces references to
@@ -367,10 +381,6 @@ public:
   }
 
 private:
-  /// Attempts to fold the given operation and places new results within
-  /// 'results'.
-  void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
-
   Block *block = nullptr;
   Block::iterator insertPoint;
 };
index 96a1a68..22d4ec1 100644 (file)
@@ -99,6 +99,14 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
   addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>();
 }
 
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder,
+                                                 Attribute value, Type type,
+                                                 Location loc) {
+  return builder.create<ConstantOp>(loc, type, value);
+}
+
 /// A utility function to check if a given region is attached to a function.
 static bool isFunctionRegion(Region *region) {
   return llvm::isa<FuncOp>(region->getParentOp());
index 531be29..713546f 100644 (file)
@@ -163,6 +163,14 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
   addInterfaces<StdInlinerInterface>();
 }
 
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
+                                                   Attribute value, Type type,
+                                                   Location loc) {
+  return builder.create<ConstantOp>(loc, type, value);
+}
+
 void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
                                  Operation::operand_iterator end,
                                  unsigned numDims, OpAsmPrinter &p) {
index a2345fe..ae5579d 100644 (file)
@@ -40,7 +40,7 @@ using namespace mlir::vector;
 // VectorOpsDialect
 //===----------------------------------------------------------------------===//
 
-mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
 #define GET_OP_LIST
@@ -48,6 +48,14 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
       >();
 }
 
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
+                                                 Attribute value, Type type,
+                                                 Location loc) {
+  return builder.create<ConstantOp>(loc, type, value);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
index 8c54df4..691b2ad 100644 (file)
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/Functional.h"
+#include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
@@ -339,27 +340,68 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
 }
 
 /// Attempts to fold the given operation and places new results within
-/// 'results'.
-void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+/// 'results'. Returns success if the operation was folded, failure otherwise.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult OpBuilder::tryFold(Operation *op,
+                                 SmallVectorImpl<Value *> &results) {
   results.reserve(op->getNumResults());
-  SmallVector<OpFoldResult, 4> foldResults;
-
-  // Returns if the given fold result corresponds to a valid existing value.
-  auto isValidValue = [](OpFoldResult result) {
-    return result.dyn_cast<Value *>();
+  auto cleanupFailure = [&] {
+    results.assign(op->result_begin(), op->result_end());
+    return failure();
   };
 
-  // Check if the fold failed, or did not result in only existing values.
+  // If this operation is already a constant, there is nothing to do.
+  Attribute unused;
+  if (matchPattern(op, m_Constant(&unused)))
+    return cleanupFailure();
+
+  // Check to see if any operands to the operation is constant and whether
+  // the operation knows how to constant fold itself.
   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
-  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
-      !llvm::all_of(foldResults, isValidValue)) {
-    // Simply return the existing operation results.
-    results.assign(op->result_begin(), op->result_end());
-    return;
+  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
+
+  // Try to fold the operation.
+  SmallVector<OpFoldResult, 4> foldResults;
+  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
+    return cleanupFailure();
+
+  // A temporary builder used for creating constants during folding.
+  OpBuilder cstBuilder(context);
+  SmallVector<Operation *, 1> generatedConstants;
+
+  // Populate the results with the folded results.
+  Dialect *dialect = op->getDialect();
+  for (auto &it : llvm::enumerate(foldResults)) {
+    // Normal values get pushed back directly.
+    if (auto *value = it.value().dyn_cast<Value *>()) {
+      results.push_back(value);
+      continue;
+    }
+
+    // Otherwise, try to materialize a constant operation.
+    if (!dialect)
+      return cleanupFailure();
+
+    // Ask the dialect to materialize a constant operation for this value.
+    Attribute attr = it.value().get<Attribute>();
+    auto *constOp = dialect->materializeConstant(
+        cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc());
+    if (!constOp) {
+      // Erase any generated constants.
+      for (Operation *cst : generatedConstants)
+        cst->erase();
+      return cleanupFailure();
+    }
+    assert(matchPattern(constOp, m_Constant(&attr)));
+
+    generatedConstants.push_back(constOp);
+    results.push_back(constOp->getResult(0));
   }
 
-  // Populate the results with the folded results and remove the original op.
-  llvm::transform(foldResults, std::back_inserter(results),
-                  [](OpFoldResult result) { return result.get<Value *>(); });
-  op->erase();
+  // If we were successful, insert any generated constants.
+  for (Operation *cst : generatedConstants)
+    insert(cst);
+
+  return success();
 }
index ea4ad68..ac13bc2 100644 (file)
@@ -25,7 +25,6 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -938,6 +937,10 @@ public:
   ConversionTarget &getTarget() { return target; }
 
 private:
+  /// Attempt to legalize the given operation by folding it.
+  LogicalResult legalizeWithFold(Operation *op,
+                                 ConversionPatternRewriter &rewriter);
+
   /// Attempt to legalize the given operation by applying the provided pattern.
   /// Returns success if the operation was legalized, failure otherwise.
   LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
@@ -1003,6 +1006,14 @@ OperationLegalizer::legalize(Operation *op,
     return success();
   }
 
+  // If the operation isn't legal, try to fold it in-place.
+  // TODO(riverriddle) Should we always try to do this, even if the op is
+  // already legal?
+  if (succeeded(legalizeWithFold(op, rewriter))) {
+    LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n");
+    return success();
+  }
+
   // Otherwise, we need to apply a legalization pattern to this operation.
   auto it = legalizerPatterns.find(op->getName());
   if (it == legalizerPatterns.end()) {
@@ -1020,6 +1031,36 @@ OperationLegalizer::legalize(Operation *op,
 }
 
 LogicalResult
+OperationLegalizer::legalizeWithFold(Operation *op,
+                                     ConversionPatternRewriter &rewriter) {
+  auto &rewriterImpl = rewriter.getImpl();
+  RewriterState curState = rewriterImpl.getCurrentState();
+
+  // Try to fold the operation.
+  SmallVector<Value *, 2> replacementValues;
+  rewriter.setInsertionPoint(op);
+  if (failed(rewriter.tryFold(op, replacementValues)))
+    return failure();
+
+  // Insert a replacement for 'op' with the folded replacement values.
+  rewriter.replaceOp(op, replacementValues);
+
+  // Recursively legalize any new constant operations.
+  for (unsigned i = curState.numCreatedOperations,
+                e = rewriterImpl.createdOps.size();
+       i != e; ++i) {
+    Operation *cstOp = rewriterImpl.createdOps[i];
+    if (failed(legalize(cstOp, rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '"
+                              << cstOp->getName() << "' was illegal.\n");
+      rewriterImpl.resetState(curState);
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult
 OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
                                     ConversionPatternRewriter &rewriter) {
   LLVM_DEBUG({
index 9d8b047..14e46aa 100644 (file)
@@ -364,9 +364,8 @@ func @multireturn_caller() {
   return
 }
 
-// CHECK-LABEL: func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
-func @vector_ops(vector<4xf32>, vector<4xi1>, vector<4xi64>) -> vector<4xf32> {
-^bb0(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>):
+// CHECK-LABEL: func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">, %arg3: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
+func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> {
 // CHECK-NEXT:  %0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
   %0 = constant dense<42.> : vector<4xf32>
 // CHECK-NEXT:  %1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
@@ -383,12 +382,12 @@ func @vector_ops(vector<4xf32>, vector<4xi1>, vector<4xi64>) -> vector<4xf32> {
   %7 = divf %arg0, %0 : vector<4xf32>
 // CHECK-NEXT:  %7 = llvm.frem %arg0, %0 : !llvm<"<4 x float>">
   %8 = remf %arg0, %0 : vector<4xf32>
-// CHECK-NEXT:  %8 = llvm.and %arg2, %arg2 : !llvm<"<4 x i64>">
-  %9 = and %arg2, %arg2 : vector<4xi64>
-// CHECK-NEXT:  %9 = llvm.or %arg2, %arg2 : !llvm<"<4 x i64>">
-  %10 = or %arg2, %arg2 : vector<4xi64>
-// CHECK-NEXT:  %10 = llvm.xor %arg2, %arg2 : !llvm<"<4 x i64>">
-  %11 = xor %arg2, %arg2 : vector<4xi64>
+// CHECK-NEXT:  %8 = llvm.and %arg2, %arg3 : !llvm<"<4 x i64>">
+  %9 = and %arg2, %arg3 : vector<4xi64>
+// CHECK-NEXT:  %9 = llvm.or %arg2, %arg3 : !llvm<"<4 x i64>">
+  %10 = or %arg2, %arg3 : vector<4xi64>
+// CHECK-NEXT:  %10 = llvm.xor %arg2, %arg3 : !llvm<"<4 x i64>">
+  %11 = xor %arg2, %arg3 : vector<4xi64>
   return %1 : vector<4xf32>
 }
 
@@ -498,23 +497,21 @@ func @integer_extension_and_truncation() {
 }
 
 // CHECK-LABEL: @dfs_block_order
-func @dfs_block_order() -> (i32) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(42 : i32) : !llvm.i32
+func @dfs_block_order(%arg0: i32) -> (i32) {
+// CHECK-NEXT:  %[[CST:.*]] = llvm.mlir.constant(42 : i32) : !llvm.i32
   %0 = constant 42 : i32
 // CHECK-NEXT:  llvm.br ^bb2
   br ^bb2
 
 // CHECK-NEXT: ^bb1:
-// CHECK-NEXT:  %1 = llvm.add %0, %2 : !llvm.i32
-// CHECK-NEXT:  llvm.return %1 : !llvm.i32
+// CHECK-NEXT:  %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : !llvm.i32
+// CHECK-NEXT:  llvm.return %[[ADD]] : !llvm.i32
 ^bb1:
-  %2 = addi %0, %1 : i32
+  %2 = addi %arg0, %0 : i32
   return %2 : i32
 
 // CHECK-NEXT: ^bb2:
 ^bb2:
-// CHECK-NEXT:  %2 = llvm.mlir.constant(55 : i32) : !llvm.i32
-  %1 = constant 55 : i32
 // CHECK-NEXT:  llvm.br ^bb1
   br ^bb1
 }
index 5825ae5..ae933bc 100644 (file)
@@ -387,8 +387,7 @@ func @min_reduction_tree(%v : index) {
 #map6 = (d0,d1,d2) -> (d0 + d1 + d2)
 
 // CHECK-LABEL: func @affine_applies(
-func @affine_applies() {
-^bb0:
+func @affine_applies(%arg0 : index) {
 // CHECK: %[[c0:.*]] = constant 0 : index
   %zero = affine.apply #map0()
 
@@ -405,39 +404,25 @@ func @affine_applies() {
 // CHECK-NEXT: %[[v1:.*]] = addi %[[v0]], %[[c1]] : index
   %one = affine.apply #map3(%symbZero)[%zero]
 
-// CHECK-NEXT: %[[c103:.*]] = constant 103 : index
-// CHECK-NEXT: %[[c104:.*]] = constant 104 : index
-// CHECK-NEXT: %[[c105:.*]] = constant 105 : index
-// CHECK-NEXT: %[[c106:.*]] = constant 106 : index
-// CHECK-NEXT: %[[c107:.*]] = constant 107 : index
-// CHECK-NEXT: %[[c108:.*]] = constant 108 : index
-// CHECK-NEXT: %[[c109:.*]] = constant 109 : index
-  %103 = constant 103 : index
-  %104 = constant 104 : index
-  %105 = constant 105 : index
-  %106 = constant 106 : index
-  %107 = constant 107 : index
-  %108 = constant 108 : index
-  %109 = constant 109 : index
 // CHECK-NEXT: %[[c2:.*]] = constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = muli %[[c104]], %[[c2]] : index
-// CHECK-NEXT: %[[v3:.*]] = addi %[[c103]], %[[v2]] : index
+// CHECK-NEXT: %[[v2:.*]] = muli %arg0, %[[c2]] : index
+// CHECK-NEXT: %[[v3:.*]] = addi %arg0, %[[v2]] : index
 // CHECK-NEXT: %[[c3:.*]] = constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = muli %[[c105]], %[[c3]] : index
+// CHECK-NEXT: %[[v4:.*]] = muli %arg0, %[[c3]] : index
 // CHECK-NEXT: %[[v5:.*]] = addi %[[v3]], %[[v4]] : index
 // CHECK-NEXT: %[[c4:.*]] = constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = muli %[[c106]], %[[c4]] : index
+// CHECK-NEXT: %[[v6:.*]] = muli %arg0, %[[c4]] : index
 // CHECK-NEXT: %[[v7:.*]] = addi %[[v5]], %[[v6]] : index
 // CHECK-NEXT: %[[c5:.*]] = constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = muli %[[c107]], %[[c5]] : index
+// CHECK-NEXT: %[[v8:.*]] = muli %arg0, %[[c5]] : index
 // CHECK-NEXT: %[[v9:.*]] = addi %[[v7]], %[[v8]] : index
 // CHECK-NEXT: %[[c6:.*]] = constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = muli %[[c108]], %[[c6]] : index
+// CHECK-NEXT: %[[v10:.*]] = muli %arg0, %[[c6]] : index
 // CHECK-NEXT: %[[v11:.*]] = addi %[[v9]], %[[v10]] : index
 // CHECK-NEXT: %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = muli %[[c109]], %[[c7]] : index
+// CHECK-NEXT: %[[v12:.*]] = muli %arg0, %[[c7]] : index
 // CHECK-NEXT: %[[v13:.*]] = addi %[[v11]], %[[v12]] : index
-  %four = affine.apply #map4(%103,%104,%105,%106)[%107,%108,%109]
+  %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
   return
 }
 
index efb59b0..38f87dd 100644 (file)
@@ -113,6 +113,16 @@ func @up_to_date_replacement(%arg: i8) -> i8 {
   return %repl_2 : i8
 }
 
+// CHECK-LABEL: func @remove_foldable_op
+// CHECK-SAME:                          (%[[ARG_0:[a-z0-9]*]]: i32)
+func @remove_foldable_op(%arg0 : i32) -> (i32) {
+  // CHECK-NEXT: return %[[ARG_0]]
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i32) -> (i32)
+  return %0 : i32
+}
+
 // -----
 
 func @fail_to_convert_illegal_op() -> i32 {
index 9d85c7d..94eb792 100644 (file)
@@ -375,7 +375,8 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
-    target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
+    target
+        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
       // Don't allow F32 operands.
       return llvm::none_of(op.getOperandTypes(),