[mlir] verify that transform ops have memory effects
authorAlex Zinenko <zinenko@google.com>
Tue, 10 Jan 2023 12:19:53 +0000 (12:19 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 10 Jan 2023 13:49:40 +0000 (13:49 +0000)
Add a verifier to the TransformOpInterface ensuring that operations
implementing the interface define memory effects on their operands and
results.

Add the missing effects to TileToForeachThreadOp, specifically for
operands that were added at a later version of the op without modifying
`getEffects` accordingly.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141371

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

index b2c3827..6dbd121 100644 (file)
@@ -494,6 +494,9 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+
+/// Verification hook for TransformOpInterface.
+LogicalResult verifyTransformOpInterface(Operation *op);
 } // namespace detail
 
 /// This trait is supposed to be attached to Transform dialect operations that
index b0b92da..f4d66c5 100644 (file)
@@ -101,6 +101,10 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       return diag;
     }
   }];
+
+  let verify = [{
+    return ::mlir::transform::detail::verifyTransformOpInterface($_op);
+  }];
 }
 
 class TransformTypeInterfaceBase<string cppClass, string cppObjectType>
index f170d0b..5097396 100644 (file)
@@ -1760,6 +1760,8 @@ void transform::TileToForeachThreadOp::getEffects(
   consumesHandle(getTarget(), effects);
   onlyReadsHandle(getTileSizes(), effects);
   onlyReadsHandle(getNumThreads(), effects);
+  onlyReadsHandle(getPackedNumThreads(), effects);
+  onlyReadsHandle(getPackedTileSizes(), effects);
   producesHandle(getResults(), effects);
 }
 
index e2e8aa2..b8a4ee7 100644 (file)
@@ -616,8 +616,8 @@ void transform::consumesHandle(
 
 /// Returns `true` if the given list of effects instances contains an instance
 /// with the effect type specified as template parameter.
-template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
-static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
+template <typename EffectTy, typename ResourceTy, typename Range>
+static bool hasEffect(Range &&effects) {
   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
     return isa<EffectTy>(effect.getEffect()) &&
            isa<ResourceTy>(effect.getResource());
@@ -665,6 +665,48 @@ void transform::onlyReadsPayload(
 }
 
 //===----------------------------------------------------------------------===//
+// Utilities for TransformOpInterface.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
+  auto iface = cast<MemoryEffectOpInterface>(op);
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+
+  auto effectsOn = [&](Value value) {
+    return llvm::make_filter_range(
+        effects, [value](const MemoryEffects::EffectInstance &instance) {
+          return instance.getValue() == value;
+        });
+  };
+
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto range = effectsOn(operand.get());
+    if (range.empty()) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires memory effects "
+                             "on operands to be specified";
+      diag.attachNote() << "no effects specified for operand #"
+                        << operand.getOperandNumber();
+      return diag;
+    }
+  }
+  for (OpResult result : op->getResults()) {
+    auto range = effectsOn(result);
+    if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
+            range)) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires 'allocate' memory "
+                             "effect to be specified for results";
+      diag.attachNote() << "no 'allocate' effect specified for result #"
+                        << result.getResultNumber();
+      return diag;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // Entry point.
 //===----------------------------------------------------------------------===//
 
index ec3f553..e957d7a 100644 (file)
@@ -210,3 +210,21 @@ transform.sequence failures(propagate) {
   // expected-note @below {{used here as operand #0}}
   transform.test_consume_operand %0
 }
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
+  // expected-note @below {{no effects specified for operand #0}}
+  transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}}
+  // expected-note @below {{no 'allocate' effect specified for result #0}}
+  transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op
+}
index 338d72e..63d6828 100644 (file)
@@ -471,7 +471,9 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
 }
 
 void mlir::test::TestProduceNullParamOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::producesHandle(getOut(), effects);
+}
 
 DiagnosedSilenceableFailure
 mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
@@ -480,6 +482,23 @@ mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (getHasOperandEffect())
+    transform::consumesHandle(getIn(), effects);
+
+  if (getHasResultEffect())
+    transform::producesHandle(getOut(), effects);
+  else
+    transform::onlyReadsHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn()));
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
index 9ff5e30..02e8a69 100644 (file)
@@ -352,4 +352,16 @@ def TestProduceNullParamOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestRequiredMemoryEffectsOp
+  : Op<Transform_Dialect, "test_required_memory_effects",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in,
+                       UnitAttr:$has_operand_effect,
+                       UnitAttr:$has_result_effect);
+  let results = (outs TransformHandleTypeInterface:$out);
+  let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD