[mlir][PDL] Infer result types from a `replace` as the last resort
authorRiver Riddle <riddleriver@gmail.com>
Thu, 1 Sep 2022 17:36:00 +0000 (10:36 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 8 Sep 2022 18:35:31 +0000 (11:35 -0700)
This prevents situations where explicit results types were provided,
which have different types than the operation being replaced. This
is useful for supporting dialect conversion, which will have proper
support added in a followup.

Differential Revision: https://reviews.llvm.org/D133141

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir

index 42ea3c9..e4cce8d 100644 (file)
@@ -835,28 +835,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
     pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
     SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
     bool &hasInferredResultTypes) {
-  // Look for an operation that was replaced by `op`. The result types will be
-  // inferred from the results that were replaced.
   Block *rewriterBlock = op->getBlock();
-  for (OpOperand &use : op.op().getUses()) {
-    // Check that the use corresponds to a ReplaceOp and that it is the
-    // replacement value, not the operation being replaced.
-    pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
-    if (!replOpUser || use.getOperandNumber() == 0)
-      continue;
-    // Make sure the replaced operation was defined before this one.
-    Value replOpVal = replOpUser.operation();
-    Operation *replacedOp = replOpVal.getDefiningOp();
-    if (replacedOp->getBlock() == rewriterBlock &&
-        !replacedOp->isBeforeInBlock(op))
-      continue;
-
-    Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
-        replacedOp->getLoc(), mapRewriteValue(replOpVal));
-    types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
-        replacedOp->getLoc(), replacedOpResults));
-    return;
-  }
 
   // Try to handle resolution for each of the result types individually. This is
   // preferred over type inferrence because it will allow for us to use existing
@@ -895,6 +874,31 @@ void PatternLowering::generateOperationResultTypeRewriter(
     return;
   }
 
+  // Look for an operation that was replaced by `op`. The result types will be
+  // inferred from the results that were replaced.
+  for (OpOperand &use : op.op().getUses()) {
+    // Check that the use corresponds to a ReplaceOp and that it is the
+    // replacement value, not the operation being replaced.
+    pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
+    if (!replOpUser || use.getOperandNumber() == 0)
+      continue;
+    // Make sure the replaced operation was defined before this one. PDL
+    // rewrites only have single block regions, so if the op isn't in the
+    // rewriter block (i.e. the current block of the operation) we already know
+    // it dominates (i.e. it's in the matcher).
+    Value replOpVal = replOpUser.operation();
+    Operation *replacedOp = replOpVal.getDefiningOp();
+    if (replacedOp->getBlock() == rewriterBlock &&
+        !replacedOp->isBeforeInBlock(op))
+      continue;
+
+    Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
+        replacedOp->getLoc(), mapRewriteValue(replOpVal));
+    types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
+        replacedOp->getLoc(), replacedOpResults));
+    return;
+  }
+
   // If the types could not be inferred from any context and there weren't any
   // explicit result types, assume the user actually meant for the operation to
   // have no results.
index 8d18596..d6e8f4a 100644 (file)
@@ -86,7 +86,7 @@ module @operation_infer_types_from_replaceop {
     %root = operation "foo.op" -> (%rootType, %rootType1 : !pdl.type, !pdl.type)
     rewrite %root {
       %newType1 = type
-      %newOp = operation "foo.op" -> (%rootType, %newType1 : !pdl.type, !pdl.type)
+      %newOp = operation "foo.op"
       replace %root with %newOp
     }
   }