[MLIR][Transform] Consolidate the transform ops of get_parent_for and loop unroll...
authorAmy Wang <kai.ting.wang@huawei.com>
Wed, 30 Nov 2022 15:59:13 +0000 (10:59 -0500)
committerPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Wed, 30 Nov 2022 16:07:44 +0000 (11:07 -0500)
This patch consolidates the two transform ops from the affine dialect
and the scf dialect to avoid code duplication.

This is to address the review comments from
https://reviews.llvm.org/D137997.

The transform ops directory / file structure for the affine dialect is
kept for the purpose of forth-coming transform ops
for affine, but get_parent_for and unroll are removed.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138980

mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/test/Dialect/Affine/transform-ops.mlir [deleted file]
mlir/test/Dialect/SCF/transform-ops.mlir

index dc59435..e2b7e50 100644 (file)
@@ -18,61 +18,4 @@ include "mlir/IR/OpBase.td"
 
 def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">;
 
-def AffineGetParentForOp : Op<Transform_Dialect, "affine.get_parent_for", [
-  NavigationTransformOpTrait, MemoryEffectsOpInterface,
-  DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let summary =
-      "Gets a handle to the parent 'affine.for' loop of the given operation";
-  let description = [{
-    Produces a handle to the n-th (default 1) parent `affine.for` loop for each
-    Payload IR operation associated with the operand. Fails if such a loop
-    cannot be found. The list of operations associated with the handle contains
-    parent operations in the same order as the list associated with the operand,
-    except for operations that are parents to more than one input which are only
-    present once.
-  }];
-
-  let arguments =
-    (ins TransformTypeInterface:$target,
-         DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
-                           "1">:$num_loops);
-  let results = (outs TransformTypeInterface:$parent);
-
-  let assemblyFormat =
-      "$target attr-dict `:` functional-type(operands, results)";
-}
-
-
-def AffineLoopUnrollOp : Op<Transform_Dialect, "affine.unroll", [
-  FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-  TransformOpInterface, TransformEachOpTrait]> {
-  let summary = "Unrolls the given loop with the given unroll factor";
-  let description = [{
-    Unrolls each loop associated with the given handle to have up to the given
-    number of loop body copies per iteration. If the unroll factor is larger
-    than the loop trip count, the latter is used as the unroll factor instead.
-
-    #### Return modes
-
-    This operation ignores non-affine::For ops and drops them in the return.
-    If all the operations referred to by the `target` PDLOperation unroll
-    properly, the transform succeeds. Otherwise the transform silently fails.
-
-    Does not return handles as the operation may result in the loop being
-    removed after a full unrolling.
-  }];
-
-  let arguments = (ins Transform_AffineForOp:$target,
-                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
-
-  let assemblyFormat = "$target attr-dict `:` type($target)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::AffineForOp target,
-        ::llvm::SmallVector<::mlir::Operation *> & results,
-        ::mlir::transform::TransformState & state);
-  }];
-}
-
 #endif // Affine_TRANSFORM_OPS
index d1c5c59..59d25da 100644 (file)
@@ -23,19 +23,20 @@ def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Gets a handle to the parent 'for' loop of the given operation";
   let description = [{
-    Produces a handle to the n-th (default 1) parent `scf.for` loop for each
-    Payload IR operation associated with the operand. Fails if such a loop
-    cannot be found. The list of operations associated with the handle contains
-    parent operations in the same order as the list associated with the operand,
-    except for operations that are parents to more than one input which are only
-    present once.
+    Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
+    (when the affine flag is true) loop for each Payload IR operation
+    associated with the operand. Fails if such a loop cannot be found. The list
+    of operations associated with the handle contains parent operations in the
+    same order as the list associated with the operand, except for operations
+    that are parents to more than one input which are only present once.
   }];
 
   let arguments =
     (ins TransformTypeInterface:$target,
          DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
-                           "1">:$num_loops);
-  let results = (outs TransformTypeInterface:$parent);
+                           "1">:$num_loops,
+         DefaultValuedAttr<BoolAttr, "false">:$affine);
+  let results = (outs TransformTypeInterface : $parent);
 
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
@@ -166,22 +167,23 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
 
     #### Return modes
 
-    This operation ignores non-scf::For ops and drops them in the return.
-    If all the operations referred to by the `target` PDLOperation unroll
-    properly, the transform succeeds. Otherwise the transform silently fails.
+    This operation ignores non-scf::For, non-affine::For ops and drops them in
+    the return.  If all the operations referred to by the `target` PDLOperation
+    unroll properly, the transform succeeds. Otherwise the transform silently
+    fails.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
   }];
 
-  let arguments = (ins Transform_ScfForOp:$target,
+  let arguments = (ins TransformTypeInterface:$target,
                        ConfinedAttr<I64Attr, [IntPositive]>:$factor);
 
   let assemblyFormat = "$target attr-dict `:` type($target)";
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForOp target,
+        ::mlir::Operation *target,
         ::llvm::SmallVector<::mlir::Operation *> &results,
         ::mlir::transform::TransformState &state);
   }];
index 7c32166..605c07f 100644 (file)
@@ -23,52 +23,6 @@ public:
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// AffineGetParentForOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::AffineGetParentForOp::apply(transform::TransformResults &results,
-                                       transform::TransformState &state) {
-  SetVector<Operation *> parents;
-  for (Operation *target : state.getPayloadOps(getTarget())) {
-    AffineForOp loop;
-    Operation *current = target;
-    for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
-      loop = current->getParentOfType<AffineForOp>();
-      if (!loop) {
-        DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                           << "could not find an '"
-                                           << AffineForOp::getOperationName()
-                                           << "' parent";
-        diag.attachNote(target->getLoc()) << "target op";
-        results.set(getResult().cast<OpResult>(), {});
-        return diag;
-      }
-      current = loop;
-    }
-    parents.insert(loop);
-  }
-  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilenceableFailure::success();
-}
-
-//===----------------------------------------------------------------------===//
-// LoopUnrollOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::AffineLoopUnrollOp::applyToOne(AffineForOp target,
-                                          SmallVector<Operation *> &results,
-                                          transform::TransformState &state) {
-  if (failed(loopUnrollByFactor(target, getFactor()))) {
-    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
-    diag << "op failed to unroll";
-    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
-  }
-  return DiagnosedSilenceableFailure(success());
-}
-
-//===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
 
index ec8516c..ca05715 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -30,21 +31,23 @@ public:
 //===----------------------------------------------------------------------===//
 // GetParentForOp
 //===----------------------------------------------------------------------===//
-
 DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SetVector<Operation *> parents;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    scf::ForOp loop;
-    Operation *current = target;
+    Operation *loop, *current = target;
     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
-      loop = current->getParentOfType<scf::ForOp>();
+      loop = getAffine() ? current->getParentOfType<AffineForOp>()
+                         : current->getParentOfType<scf::ForOp>();
+
       if (!loop) {
-        DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                           << "could not find an '"
-                                           << scf::ForOp::getOperationName()
-                                           << "' parent";
+        DiagnosedSilenceableFailure diag =
+            emitSilenceableError()
+            << "could not find an '"
+            << (getAffine() ? AffineForOp::getOperationName()
+                            : scf::ForOp::getOperationName())
+            << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
         results.set(getResult().cast<OpResult>(), {});
         return diag;
@@ -215,12 +218,18 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::LoopUnrollOp::applyToOne(scf::ForOp target,
+transform::LoopUnrollOp::applyToOne(Operation *op,
                                     SmallVector<Operation *> &results,
                                     transform::TransformState &state) {
-  if (failed(loopUnrollByFactor(target, getFactor()))) {
-    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
-    diag << "op failed to unroll";
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note);
+    diag << "Op failed to unroll";
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
   return DiagnosedSilenceableFailure(success());
diff --git a/mlir/test/Dialect/Affine/transform-ops.mlir b/mlir/test/Dialect/Affine/transform-ops.mlir
deleted file mode 100644 (file)
index 0a12209..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
-  // expected-remark @below {{first loop}}
-  affine.for %i = %arg0 to %arg1 {
-    // expected-remark @below {{second loop}}
-    affine.for %j = %arg0 to %arg1 {
-      // expected-remark @below {{third loop}}
-      affine.for %k = %arg0 to %arg1 {
-        arith.addi %i, %j : index
-      }
-    }
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  // CHECK: = transform.affine.get_parent_for
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-  %2 = transform.affine.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"affine.for">
-  %3 = transform.affine.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
-  // expected-note @below {{target op}}
-  arith.addi %arg0, %arg1 : index
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  // expected-error @below {{could not find an 'affine.for' parent}}
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-}
-
-// -----
-
-func.func @loop_unroll_op() {
-  %c0 = arith.constant 0 : index
-  %c42 = arith.constant 42 : index
-  %c5 = arith.constant 5 : index
-  // CHECK: affine.for %[[I:.+]] =
-  // expected-remark @below {{affine for loop}}
-  affine.for %i = %c0 to %c42 {
-    // CHECK-COUNT-4: arith.addi
-    arith.addi %i, %i : index
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
-  transform.affine.unroll %1 { factor = 4 } : !transform.op<"affine.for">
-}
-
index baca3c8..d6ff2f2 100644 (file)
@@ -192,3 +192,94 @@ transform.sequence failures(propagate) {
   transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
 }
 
+// -----
+
+// CHECK-LABEL: @get_parent_for_op
+func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
+  // expected-remark @below {{first loop}}
+  affine.for %i = %arg0 to %arg1 {
+    // expected-remark @below {{second loop}}
+    affine.for %j = %arg0 to %arg1 {
+      // expected-remark @below {{third loop}}
+      affine.for %k = %arg0 to %arg1 {
+        arith.addi %i, %j : index
+      }
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  // CHECK: = transform.loop.get_parent_for
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  arith.addi %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  // expected-error @below {{could not find an 'affine.for' parent}}
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @loop_unroll_op() {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: affine.for %[[I:.+]] =
+  // expected-remark @below {{affine for loop}}
+  affine.for %i = %c0 to %c42 {
+    // CHECK-COUNT-4: arith.addi
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
+  transform.loop.unroll %1 { factor = 4, affine = true } : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @test_mixed_loops() {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c5 = arith.constant 5 : index
+  scf.for %j = %c0 to %c42 step %c5 {
+    // CHECK: affine.for %[[I:.+]] =
+    // expected-remark @below {{affine for loop}}
+    affine.for %i = %c0 to %c42 {
+      // CHECK-COUNT-4: arith.addi
+      arith.addi %i, %i : index
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  %1 = transform.loop.get_parent_for %0 { num_loops = 1, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
+  transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for">
+}