[mlir][spirv] Add conversions from arith.bitcast, std.br, std.cond_br to spirv.
authorxndcn <xndchn@gmail.com>
Fri, 29 Oct 2021 14:45:18 +0000 (22:45 +0800)
committerxndcn <xndchn@gmail.com>
Sat, 30 Oct 2021 16:40:35 +0000 (00:40 +0800)
Differential Revision: https://reviews.llvm.org/D112819

mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index d1eed3231d6f429caf6d87b5917fe119ab7d6053..6fd69637df1d8584cc476012ae0a54e9a79dd034 100644 (file)
@@ -784,6 +784,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern
   >(typeConverter, patterns.getContext());
index 48b6e99257bb5764090af282a8e12a583a68653c..87d57080ed80b6c5e253034a810061f5a2fc6893 100644 (file)
@@ -65,6 +65,24 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.br to spv.Branch.
+struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
+  using OpConversionPattern<BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts std.cond_br to spv.BranchConditional.
+struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
+  using OpConversionPattern<CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts tensor.extract into loading using access chains from SPIR-V local
 /// variables.
 class TensorExtractPattern final
@@ -176,6 +194,31 @@ SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
+                                               adaptor.getDestOperands());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CondBranchOpPattern::matchAndRewrite(
+    CondBranchOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
+      op.getFalseDest(), adaptor.getFalseDestOperands());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -194,7 +237,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
       spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
 
-      ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
+      ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
+      CondBranchOpPattern>(typeConverter, context);
 }
 
 void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
index 0c2aee27ca5085bccb82be55f93853e3ced59e2f..8a41a90a2fc0b488f3a2011d0c619fcfb923356b 100644 (file)
@@ -572,6 +572,15 @@ func @index_cast4(%arg0: index) {
   return
 }
 
+// CHECK-LABEL: @bit_cast
+func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
+  // CHECK: spv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
+  %0 = arith.bitcast %arg0 : vector<2xf32> to vector<2xi32>
+  // CHECK: spv.Bitcast %{{.+}} : i64 to f64
+  %1 = arith.bitcast %arg1 : i64 to f64
+  return
+}
+
 // CHECK-LABEL: @fpext1
 func @fpext1(%arg0: f16) -> f64 {
   // CHECK: spv.FConvert %{{.*}} : f16 to f64
index 36a6d793e72147005909fb4385f41da6b1a61e91..b8d9966c9a5bfd1b61eca96f25fdfed95480078e 100644 (file)
@@ -933,3 +933,45 @@ func @splat(%f : f32) -> vector<4xf32> {
   %splat = splat %f : vector<4xf32>
   return %splat : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.br, std.cond_br
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: func @simple_loop
+func @simple_loop(index, index, index) {
+^bb0(%begin : index, %end : index, %step : index):
+// CHECK-NEXT:  spv.Branch ^bb1
+  br ^bb1
+
+// CHECK-NEXT: ^bb1:    // pred: ^bb0
+// CHECK-NEXT:  spv.Branch ^bb2({{.*}} : i32)
+^bb1:   // pred: ^bb0
+  br ^bb2(%begin : index)
+
+// CHECK:      ^bb2({{.*}}: i32):       // 2 preds: ^bb1, ^bb3
+// CHECK-NEXT:  {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
+// CHECK-NEXT:  spv.BranchConditional {{.*}}, ^bb3, ^bb4
+^bb2(%0: index):        // 2 preds: ^bb1, ^bb3
+  %1 = arith.cmpi slt, %0, %end : index
+  cond_br %1, ^bb3, ^bb4
+
+// CHECK:      ^bb3:    // pred: ^bb2
+// CHECK-NEXT:  {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
+// CHECK-NEXT:  spv.Branch ^bb2({{.*}} : i32)
+^bb3:   // pred: ^bb2
+  %2 = arith.addi %0, %step : index
+  br ^bb2(%2 : index)
+
+// CHECK:      ^bb4:    // pred: ^bb2
+^bb4:   // pred: ^bb2
+  return
+}
+
+}