Add getRemappedValue to ConversionPatternRewriter
authorDiego Caballero <diego.caballero@intel.com>
Tue, 19 Nov 2019 18:15:36 +0000 (10:15 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 19 Nov 2019 19:09:39 +0000 (11:09 -0800)
This method is needed for N->1 conversion patterns to retrieve remapped
Values used in the original N operations.

Closes tensorflow/mlir#237

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/237 from dcaballe:dcaballe/getRemappedValue 1f64fadcf2b203f7b336ff0c5838b116ae3625db
PiperOrigin-RevId: 281321881

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalize-remapped-value.mlir [new file with mode: 0644]
mlir/test/lib/TestDialect/TestPatterns.cpp

index 2deb0c9..8866950 100644 (file)
@@ -338,6 +338,10 @@ public:
     return cast<OpT>(cloneWithoutRegions(op.getOperation()));
   }
 
+  /// Return the converted value that replaces 'key'. Return 'key' if there is
+  /// no such a converted value.
+  Value *getRemappedValue(Value *key);
+
   //===--------------------------------------------------------------------===//
   // PatternRewriter Hooks
   //===--------------------------------------------------------------------===//
index a2065f1..7931932 100644 (file)
@@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
   return newOp;
 }
 
+/// Return the converted value that replaces 'key'. Return 'key' if there is
+/// no such a converted value.
+Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
+  return impl->mapping.lookupOrDefault(key);
+}
+
 /// PatternRewriter hook for splitting a block into two parts.
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
diff --git a/mlir/test/Transforms/test-legalize-remapped-value.mlir b/mlir/test/Transforms/test-legalize-remapped-value.mlir
new file mode 100644 (file)
index 0000000..ff571c9
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-remapped-value | FileCheck %s
+
+// Simple test that exercises ConvertPatternRewriter::getRemappedValue.
+func @remap_input_1_to_1(%arg0: i32) {
+  %0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32
+  %1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32
+  "test.return"() : () -> ()
+}
+// CHECK-LABEL: func @remap_input_1_to_1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]])
+// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]])
+
index 936d763..5ef0360 100644 (file)
@@ -435,3 +435,66 @@ static mlir::PassRegistration<TestLegalizePatternDriver>
                      return std::make_unique<TestLegalizePatternDriver>(
                          legalizerConversionMode);
                    });
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter::getRemappedValue testing. This method is used
+// to get the remapped value of a original value that was replaced using
+// ConversionPatternRewriter.
+namespace {
+/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
+/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
+/// operand twice.
+///
+/// Example:
+///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
+/// is replaced with:
+///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
+struct OneVResOneVOperandOp1Converter
+    : public OpConversionPattern<OneVResOneVOperandOp1> {
+  using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
+
+  PatternMatchResult
+  matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto origOps = op.getOperands();
+    assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
+           "One operand expected");
+    Value *origOp = *origOps.begin();
+    SmallVector<Value *, 2> remappedOperands;
+    // Replicate the remapped original operand twice. Note that we don't used
+    // the remapped 'operand' since the goal is testing 'getRemappedValue'.
+    remappedOperands.push_back(rewriter.getRemappedValue(origOp));
+    remappedOperands.push_back(rewriter.getRemappedValue(origOp));
+
+    SmallVector<Type, 1> resultTypes(op.getResultTypes());
+    rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
+                                                       remappedOperands);
+    return matchSuccess();
+  }
+};
+
+struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
+
+    mlir::ConversionTarget target(getContext());
+    target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
+    // We make OneVResOneVOperandOp1 legal only when it has more that one
+    // operand. This will trigger the conversion that will replace one-operand
+    // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
+    target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
+        [](Operation *op) -> bool {
+          return std::distance(op->operand_begin(), op->operand_end()) > 1;
+        });
+
+    if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+} // end anonymous namespace
+
+static PassRegistration<TestRemappedValue> remapped_value_pass(
+    "test-remapped-value",
+    "Test public remapped value mechanism in ConversionPatternRewriter");