[mlir][spirv] Don't return value when cannot fold spirv.bitcast
authorLei Zhang <antiagainst@google.com>
Tue, 8 Nov 2022 00:11:46 +0000 (19:11 -0500)
committerLei Zhang <antiagainst@google.com>
Tue, 8 Nov 2022 00:11:46 +0000 (19:11 -0500)
Returing a value would make the canonicalization infrastructure
think that folding succeeded so the pattern will be tried again
when invoked via, e.g., `applyPatternsAndFoldGreedily` and
eventually fail due to not converging after 10 times by default.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137598

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

index 57e6475..b068d23 100644 (file)
@@ -117,22 +117,22 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
-  Value arg = getOperand();
-  if (getType() == arg.getType())
-    return arg;
+  Value curInput = getOperand();
+  if (getType() == curInput.getType())
+    return curInput;
 
   // Look through nested bitcasts.
-  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
-    Value nestedArg = bitcast.getOperand();
-    if (nestedArg.getType() == getType())
-      return nestedArg;
+  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
+    Value prevInput = prevCast.getOperand();
+    if (prevInput.getType() == getType())
+      return prevInput;
 
-    getOperandMutable().assign(nestedArg);
+    getOperandMutable().assign(prevInput);
     return getResult();
   }
 
   // TODO(kuhar): Consider constant-folding the operand attribute.
-  return getResult();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//