Add support for an analysis mode to DialectConversion.
authorRiver Riddle <riverriddle@google.com>
Thu, 25 Jul 2019 18:30:41 +0000 (11:30 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 25 Jul 2019 18:31:07 +0000 (11:31 -0700)
This mode analyzes which operations are legalizable to the given target if a conversion were to be applied, i.e. no rewrites are ever performed even on success. This mode is useful for device partitioning or other utilities that may want to analyze the effect of conversion to different targets before performing it.

The analysis method currently just fills a provided set with the operations that were found to be legalizable. This can be extended in the future to capture more information as necessary.

PiperOrigin-RevId: 259987105

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer-analysis.mlir [new file with mode: 0644]
mlir/test/lib/TestDialect/TestPatterns.cpp

index 1ffd5bb..0d69ed0 100644 (file)
@@ -487,6 +487,25 @@ LLVM_NODISCARD LogicalResult applyFullConversion(
 LLVM_NODISCARD LogicalResult applyFullConversion(
     Operation *op, ConversionTarget &target,
     OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
+
+/// Apply an analysis conversion on the given operations, and all nested
+/// operations. This method analyzes which operations would be successfully
+/// converted to the target if a conversion was applied. All operations that
+/// were found to be legalizable to the given 'target' are placed within the
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+/// This method only returns failure if there are unreachable blocks in any of
+/// the regions nested within 'ops', or if a type conversion failed. If
+/// 'converter' is provided, the signatures of blocks and regions are also
+/// considered for conversion.
+LLVM_NODISCARD LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, DenseSet<Operation *> &convertedOps,
+    TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    OwningRewritePatternList &&patterns, DenseSet<Operation *> &convertedOps,
+    TypeConverter *converter = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
index aac2e11..8f2518e 100644 (file)
@@ -952,6 +952,10 @@ enum OpConversionMode {
   // In this mode, all operations must be legal for the given target for the
   // conversion to succeeed.
   Full,
+
+  // In this mode, operations are analyzed for legality. No actual rewrites are
+  // applied to the operations on success.
+  Analysis,
 };
 
 // This class converts operations using the given pattern matcher. If a
@@ -960,8 +964,10 @@ enum OpConversionMode {
 struct OperationConverter {
   explicit OperationConverter(ConversionTarget &target,
                               OwningRewritePatternList &patterns,
-                              OpConversionMode mode)
-      : opLegalizer(target, patterns), mode(mode) {}
+                              OpConversionMode mode,
+                              DenseSet<Operation *> *legalizableOps = nullptr)
+      : opLegalizer(target, patterns), mode(mode),
+        legalizableOps(legalizableOps) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops,
@@ -985,6 +991,10 @@ private:
 
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
+
+  /// A set of pre-existing operations that were found to be legalizable to the
+  /// target. This field is only used when mode == OpConversionMode::Analysis.
+  DenseSet<Operation *> *legalizableOps;
 };
 } // end anonymous namespace
 
@@ -1055,6 +1065,10 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName()
              << "' that was explicitly marked illegal";
+  } else if (mode == OpConversionMode::Analysis) {
+    /// Analysis conversions don't fail if any operations fail to legalize, they
+    /// are only interested in the operations that were successfully legalized.
+    legalizableOps->insert(op);
   }
   return success();
 }
@@ -1088,8 +1102,12 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
         return rewriter.getImpl().discardRewrites(), failure();
   }
 
-  // Otherwise the body conversion succeeded, so apply all rewrites.
-  rewriter.getImpl().applyRewrites();
+  // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
+  // analysis conversion.
+  if (mode == OpConversionMode::Analysis)
+    rewriter.getImpl().discardRewrites();
+  else
+    rewriter.getImpl().applyRewrites();
   return success();
 }
 
@@ -1348,3 +1366,27 @@ LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
   return applyFullConversion(llvm::makeArrayRef(op), target,
                              std::move(patterns), converter);
 }
+
+/// Apply an analysis conversion on the given operations, and all nested
+/// operations. This method analyzes which operations would be successfully
+/// converted to the target if a conversion was applied. All operations that
+/// were found to be legalizable to the given 'target' are placed within the
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+LogicalResult mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+                                            ConversionTarget &target,
+                                            OwningRewritePatternList &&patterns,
+                                            DenseSet<Operation *> &convertedOps,
+                                            TypeConverter *converter) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+                                 &convertedOps);
+  return opConverter.convertOperations(ops, converter);
+}
+LogicalResult mlir::applyAnalysisConversion(Operation *op,
+                                            ConversionTarget &target,
+                                            OwningRewritePatternList &&patterns,
+                                            DenseSet<Operation *> &convertedOps,
+                                            TypeConverter *converter) {
+  return applyAnalysisConversion(llvm::makeArrayRef(op), target,
+                                 std::move(patterns), convertedOps, converter);
+}
diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir
new file mode 100644 (file)
index 0000000..347e058
--- /dev/null
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
+
+// expected-remark@+1 {{op 'func' is legalizable}}
+func @test(%arg0: f32) {
+  // expected-remark@+1 {{op 'test.illegal_op_a' is legalizable}}
+  %result = "test.illegal_op_a"() : () -> (i32)
+  "foo.region"() ({
+      // expected-remark@+1 {{op 'test.invalid' is legalizable}}
+      "test.invalid"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// Check that none of the legalizable operations were modified.
+// CHECK-LABEL: func @test
+// CHECK-NEXT: "test.illegal_op_a"
+// CHECK: "test.invalid"
index 1cbd253..201dfc3 100644 (file)
@@ -184,6 +184,11 @@ struct TestTypeConverter : public TypeConverter {
 
 struct TestLegalizePatternDriver
     : public ModulePass<TestLegalizePatternDriver> {
+  /// The mode of conversion to use with the driver.
+  enum class ConversionMode { Analysis, Partial };
+
+  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+
   void runOnModule() override {
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
@@ -205,12 +210,44 @@ struct TestLegalizePatternDriver
     });
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-    (void)applyPartialConversion(getModule(), target, std::move(patterns),
-                                 &converter);
+
+    // Handle a partial conversion.
+    if (mode == ConversionMode::Partial) {
+      (void)applyPartialConversion(getModule(), target, std::move(patterns),
+                                   &converter);
+      return;
+    }
+
+    // Otherwise, handle an analysis conversion.
+    assert(mode == ConversionMode::Analysis);
+
+    // Analyze the convertible operations.
+    DenseSet<Operation *> legalizedOps;
+    if (failed(applyAnalysisConversion(getModule(), target, std::move(patterns),
+                                       legalizedOps, &converter)))
+      return signalPassFailure();
+
+    // Emit remarks for each legalizable operation.
+    for (auto *op : legalizedOps)
+      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
   }
+
+  /// The mode of conversion to use.
+  ConversionMode mode;
 };
 } // end anonymous namespace
 
-static mlir::PassRegistration<TestLegalizePatternDriver>
-    legalizer_pass("test-legalize-patterns",
-                   "Run test dialect legalization patterns");
+static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
+    legalizerConversionMode(
+        "test-legalize-mode",
+        llvm::cl::desc("The legalization mode to use with the test driver"),
+        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+        llvm::cl::values(
+            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+                       "analysis", "Perform an analysis conversion"),
+            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+                       "partial", "Perform a partial conversion")));
+
+static mlir::PassRegistration<TestLegalizePatternDriver> legalizer_pass(
+    "test-legalize-patterns", "Run test dialect legalization patterns",
+    [] { return new TestLegalizePatternDriver(legalizerConversionMode); });