TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern
>(typeConverter, patterns.getContext());
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.br to spv.Branch.
+struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
+ using OpConversionPattern<BranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts std.cond_br to spv.BranchConditional.
+struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
+ using OpConversionPattern<CondBranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts tensor.extract into loading using access chains from SPIR-V local
/// variables.
class TensorExtractPattern final
return success();
}
+//===----------------------------------------------------------------------===//
+// BranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
+ adaptor.getDestOperands());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CondBranchOpPattern::matchAndRewrite(
+ CondBranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+ op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
+ op.getFalseDest(), adaptor.getFalseDestOperands());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
- ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
+ ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
+ CondBranchOpPattern>(typeConverter, context);
}
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
%splat = splat %f : vector<4xf32>
return %splat : vector<4xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.br, std.cond_br
+//===----------------------------------------------------------------------===//
+
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: func @simple_loop
+func @simple_loop(index, index, index) {
+^bb0(%begin : index, %end : index, %step : index):
+// CHECK-NEXT: spv.Branch ^bb1
+ br ^bb1
+
+// CHECK-NEXT: ^bb1: // pred: ^bb0
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb1: // pred: ^bb0
+ br ^bb2(%begin : index)
+
+// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
+// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
+^bb2(%0: index): // 2 preds: ^bb1, ^bb3
+ %1 = arith.cmpi slt, %0, %end : index
+ cond_br %1, ^bb3, ^bb4
+
+// CHECK: ^bb3: // pred: ^bb2
+// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb3: // pred: ^bb2
+ %2 = arith.addi %0, %step : index
+ br ^bb2(%2 : index)
+
+// CHECK: ^bb4: // pred: ^bb2
+^bb4: // pred: ^bb2
+ return
+}
+
+}