[mlir] add transform dialect entry point
authorAlex Zinenko <zinenko@google.com>
Tue, 11 Oct 2022 15:23:48 +0000 (15:23 +0000)
committerAlex Zinenko <zinenko@google.com>
Wed, 12 Oct 2022 08:16:28 +0000 (08:16 +0000)
Introduce `transform::applyTransforms` as a top-level entry point to the
Transform dialect-driven transformation infrastructure, by analogy with
`applyFull/PartialConversion`. Clients are expected to use this function
and no longer need to maintain the transformation state. Make the
constructor of the TransformState private for that purpose.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D135681

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

index cf5072b..5fef82b 100644 (file)
@@ -16,19 +16,18 @@ def Transform_Dialect : Dialect {
   let description = [{
     ## Disclaimer
 
-    ** Proceed with care: not ready for general use. **
+    **This dialect is actively developed and may change frequently.**
 
-    This dialect is evolving rapidly and may change on a very short notice. To
-    decrease the maintenance burden and churn, only a few in-tree use cases are
-    currently supported in the main tree:
+    To decrease the maintenance burden and churn, please post a description of
+    the intended use case on the MLIR forum. A few in-tree use cases are
+    currently supported:
 
       - high-level transformations on "structured ops" (i.e. ops that operate on
         chunks of data in a way that can be decomposed into operations on
         smaller chunks of data and control flow) in Linalg, Tensor and Vector
-        dialects.
-
-    *Please post a description of the intended use case on the MLIR forum and
-    wait for confirmation.*
+        dialects;
+      - loop transformations in the SCF dialect.
+      
 
     ## Overview
 
@@ -79,6 +78,18 @@ def Transform_Dialect : Dialect {
     expected to have the `PossibleTopLevelTransformOpTrait` and may be used
     without arguments.
 
+    A program transformation expressed using the Transform dialect can be
+    programmatically triggered by calling:
+
+    ```c++
+    LogicalResult transform::applyTransforms(Operation *payloadRoot,
+                                             TransformOpInterface transform,
+                                             const TransformOptions &options);
+    ```
+
+    that applies the transformations specified by the top-level `transform` to
+    payload IR contained in `payloadRoot`.
+
     ## Dialect Extension Mechanism
 
     This dialect is designed to be extensible, that is, clients of this dialect
index 25f61d6..2c81985 100644 (file)
@@ -206,6 +206,16 @@ private:
   bool expensiveChecksEnabled = true;
 };
 
+/// Entry point to the Transform dialect infrastructure. Applies the
+/// transformation specified by `transform` to payload IR contained in
+/// `payloadRoot`. The `transform` operation may contain other operations that
+/// will be executed following the internal logic of the operation. It must
+/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
+/// This function internally keeps track of the transformation state.
+LogicalResult
+applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+                const TransformOptions &options = TransformOptions());
+
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
 /// surrounding structure are referred to as transform IR. The operations to
@@ -250,15 +260,11 @@ class TransformState {
     TransformOpReverseMapping reverse;
   };
 
-public:
-  /// Creates a state for transform ops living in the given region. The parent
-  /// operation of the region. The second argument points to the root operation
-  /// in the payload IR being transformed, which may or may not contain the
-  /// region with transform ops. Additional options can be provided through the
-  /// trailing configuration object.
-  TransformState(Region &region, Operation *root,
-                 const TransformOptions &options = TransformOptions());
+  friend LogicalResult applyTransforms(Operation *payloadRoot,
+                                       TransformOpInterface transform,
+                                       const TransformOptions &options);
 
+public:
   /// Returns the op at which the transformation state is rooted. This is
   /// typically helpful for transformations that apply globally.
   Operation *getTopLevel() const;
@@ -438,6 +444,13 @@ private:
   /// Identifier for storing top-level value in the `operations` mapping.
   static constexpr Value kTopLevelValue = Value();
 
+  /// Creates a state for transform ops living in the given region. The second
+  /// argument points to the root operation in the payload IR being transformed,
+  /// which may or may not contain the region with transform ops. Additional
+  /// options can be provided through the trailing configuration object.
+  TransformState(Region *region, Operation *payloadRoot,
+                 const TransformOptions &options = TransformOptions());
+
   /// Returns the mappings frame for the reigon in which the value is defined.
   const Mappings &getMapping(Value value) const {
     return const_cast<TransformState *>(this)->getMapping(value);
index 60414e8..2810444 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "transform-dialect"
 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
@@ -25,14 +26,15 @@ using namespace mlir;
 
 constexpr const Value transform::TransformState::kTopLevelValue;
 
-transform::TransformState::TransformState(Region &region, Operation *root,
+transform::TransformState::TransformState(Region *region,
+                                          Operation *payloadRoot,
                                           const TransformOptions &options)
-    : topLevel(root), options(options) {
-  auto result = mappings.try_emplace(&region);
+    : topLevel(payloadRoot), options(options) {
+  auto result = mappings.try_emplace(region);
   assert(result.second && "the region scope is already present");
   (void)result;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  regionStack.push_back(&region);
+  regionStack.push_back(region);
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
 
@@ -448,6 +450,27 @@ void transform::onlyReadsPayload(
 }
 
 //===----------------------------------------------------------------------===//
+// Entry point.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::applyTransforms(Operation *payloadRoot,
+                                         TransformOpInterface transform,
+                                         const TransformOptions &options) {
+#ifndef NDEBUG
+  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+      transform->getNumOperands() != 0) {
+    transform->emitError()
+        << "expected transform to start at the top-level transform op";
+    llvm::report_fatal_error("could not run transforms",
+                             /*gen_crash_diag=*/false);
+  }
+#endif // NDEBUG
+
+  TransformState state(transform->getParentRegion(), payloadRoot, options);
+  return state.applyTransform(transform).checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//
 
index 57cbcac..735491a 100644 (file)
@@ -1,29 +1,41 @@
 // RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
 
-// expected-remark @below {{applying transformation}}
-transform.test_transform_op
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-remark @below {{applying transformation}}
+  transform.test_transform_op
+}
 
 // -----
 
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-// expected-remark @below {{succeeded}}
-transform.test_consume_operand_if_matches_param_or_fail %0[42]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+  // expected-remark @below {{succeeded}}
+  transform.test_consume_operand_if_matches_param_or_fail %0[42]
+}
 
 // -----
 
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-// expected-error @below {{expected the operand to be associated with 21 got 42}}
-transform.test_consume_operand_if_matches_param_or_fail %0[21]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+  // expected-error @below {{expected the operand to be associated with 21 got 42}}
+  transform.test_consume_operand_if_matches_param_or_fail %0[21]
+}
 
 // -----
 
 // It is okay to have multiple handles to the same payload op as long
 // as only one of them is consumed. The expensive checks mode is necessary
 // to detect double-consumption.
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-%1 = transform.test_copy_payload %0
-// expected-remark @below {{succeeded}}
-transform.test_consume_operand_if_matches_param_or_fail %0[42]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+  %1 = transform.test_copy_payload %0
+  // expected-remark @below {{succeeded}}
+  transform.test_consume_operand_if_matches_param_or_fail %0[42]
+}
 
 // -----
 
index ad5dcab..1696cae 100644 (file)
@@ -41,13 +41,12 @@ public:
 
   void runOnOperation() override {
     ModuleOp module = getOperation();
-    transform::TransformState state(
-        module.getBodyRegion(), module,
-        transform::TransformOptions().enableExpensiveChecks(
-            enableExpensiveChecks));
     for (auto op :
          module.getBody()->getOps<transform::TransformOpInterface>()) {
-      if (failed(state.applyTransform(op).checkAndReport()))
+      if (failed(transform::applyTransforms(
+              module, op,
+              transform::TransformOptions().enableExpensiveChecks(
+                  enableExpensiveChecks))))
         return signalPassFailure();
     }
   }