namespace {
-// Combine chained `spirv::AccessChainOp` operations into one
-// `spirv::AccessChainOp` operation.
+/// Combines chained `spirv::AccessChainOp` operations into one
+/// `spirv::AccessChainOp` operation.
struct CombineChainedAccessChain
: public OpRewritePattern<spirv::AccessChainOp> {
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
return matchSuccess();
}
};
-} // namespace
+} // end anonymous namespace
void spirv::AccessChainOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
return success();
}
+namespace {
+
+/// Converts chained `spirv::BitcastOp` operations into one
+/// `spirv::BitcastOp` operation.
+struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
+ using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
+ PatternRewriter &rewriter) const override {
+ auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
+ bitcastOp.operand()->getDefiningOp());
+
+ if (!parentBitcastOp) {
+ return matchFailure();
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
+ /*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
+ bitcastOp.result()->getType(), parentBitcastOp.operand());
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace
+
+void spirv::BitcastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ConvertChainedBitcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// spv.BitFieldInsert
//===----------------------------------------------------------------------===//
return matchSuccess();
}
-} // namespace
+} // end anonymous namespace
void spirv::SelectionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// -----
//===----------------------------------------------------------------------===//
+// spv.Bitcast
+//===----------------------------------------------------------------------===//
+
+func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
+ // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]]
+ %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
+ %1 = spv.Bitcast %0 : vector<2xi32> to i64
+ %2 = spv.Bitcast %1 : i64 to f64
+ spv.ReturnValue %2 : f64
+}
+
+// -----
+
+func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
+ // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
+ // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+ // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
+ %0 = spv.Bitcast %arg0 : vector<2xf32> to i64
+ %1 = spv.Bitcast %0 : i64 to f64
+ spv.Store "Uniform" %arg1, %0 : i64
+ spv.ReturnValue %1 : f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//