[mlir][spirv] Fold noop `BitcastsOp`s
authorJakub Kuderski <kubak@google.com>
Fri, 4 Nov 2022 21:37:18 +0000 (17:37 -0400)
committerJakub Kuderski <kubak@google.com>
Fri, 4 Nov 2022 21:37:30 +0000 (17:37 -0400)
This allows for bitcast conversion to roundtrip.

Fixes: https://github.com/llvm/llvm-project/issues/58801

Reviewed By: antiagainst, Hardcode84, mravishankar

Differential Revision: https://reviews.llvm.org/D137459

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

index c985c6e..8975fa0 100644 (file)
@@ -88,7 +88,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 // -----
index b3444d8..57e6475 100644 (file)
@@ -116,9 +116,23 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
 
-void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                   MLIRContext *context) {
-  results.add<ConvertChainedBitcast>(context);
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value arg = getOperand();
+  if (getType() == arg.getType())
+    return arg;
+
+  // Look through nested bitcasts.
+  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
+    Value nestedArg = bitcast.getOperand();
+    if (nestedArg.getType() == getType())
+      return nestedArg;
+
+    getOperandMutable().assign(nestedArg);
+    return getResult();
+  }
+
+  // TODO(kuhar): Consider constant-folding the operand attribute.
+  return getResult();
 }
 
 //===----------------------------------------------------------------------===//
index 12c41fc..e8d2274 100644 (file)
@@ -14,13 +14,6 @@ include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"
 
 //===----------------------------------------------------------------------===//
-// spirv.Bitcast
-//===----------------------------------------------------------------------===//
-
-def ConvertChainedBitcast : Pat<(SPIRV_BitcastOp (SPIRV_BitcastOp $operand)),
-                                (SPIRV_BitcastOp $operand)>;
-
-//===----------------------------------------------------------------------===//
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
 
index b13d644..e65f92e 100644 (file)
@@ -86,6 +86,30 @@ func.func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spirv.ptr<i
 
 // -----
 
+// CHECK-LABEL: @convert_bitcast_roundtip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to i64
+  spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bitcast_chained_roundtip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+  %2 = spirv.Bitcast %1 : vector<2xi32> to vector<2xf32>
+  %3 = spirv.Bitcast %2 : vector<2xf32> to i64
+  spirv.ReturnValue %3 : i64
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtract
 //===----------------------------------------------------------------------===//