This allows for bitcast conversion to roundtrip.
Fixes: https://github.com/llvm/llvm-project/issues/58801
Reviewed By: antiagainst, Hardcode84, mravishankar
Differential Revision: https://reviews.llvm.org/D137459
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
// -----
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
-void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ConvertChainedBitcast>(context);
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+ Value arg = getOperand();
+ if (getType() == arg.getType())
+ return arg;
+
+ // Look through nested bitcasts.
+ if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
+ Value nestedArg = bitcast.getOperand();
+ if (nestedArg.getType() == getType())
+ return nestedArg;
+
+ getOperandMutable().assign(nestedArg);
+ return getResult();
+ }
+
+ // TODO(kuhar): Consider constant-folding the operand attribute.
+ return getResult();
}
//===----------------------------------------------------------------------===//
include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"
//===----------------------------------------------------------------------===//
-// spirv.Bitcast
-//===----------------------------------------------------------------------===//
-
-def ConvertChainedBitcast : Pat<(SPIRV_BitcastOp (SPIRV_BitcastOp $operand)),
- (SPIRV_BitcastOp $operand)>;
-
-//===----------------------------------------------------------------------===//
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
// -----
+// CHECK-LABEL: @convert_bitcast_roundtip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtip(%arg0 : i64) -> i64 {
+ // CHECK: spirv.ReturnValue %[[ARG]]
+ %0 = spirv.Bitcast %arg0 : i64 to f64
+ %1 = spirv.Bitcast %0 : f64 to i64
+ spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bitcast_chained_roundtip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtip(%arg0 : i64) -> i64 {
+ // CHECK: spirv.ReturnValue %[[ARG]]
+ %0 = spirv.Bitcast %arg0 : i64 to f64
+ %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+ %2 = spirv.Bitcast %1 : vector<2xi32> to vector<2xf32>
+ %3 = spirv.Bitcast %2 : vector<2xf32> to i64
+ spirv.ReturnValue %3 : i64
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.CompositeExtract
//===----------------------------------------------------------------------===//