[mlir][StandardToSPIRV] Handle i1 case for lowering std.zexti to SPIR-V.
authorHanhan Wang <hanchung@google.com>
Wed, 3 Jun 2020 22:00:33 +0000 (15:00 -0700)
committerHanhan Wang <hanchung@google.com>
Wed, 3 Jun 2020 22:01:18 +0000 (15:01 -0700)
Differential Revision: https://reviews.llvm.org/D80965

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index facdbf7..5f960aa 100644 (file)
@@ -445,6 +445,28 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.zexti to spv.Select if the type of source is i1.
+class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
+public:
+  using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = operands.front().getType();
+    if (!srcType.isSignlessInteger() || srcType.getIntOrFloatBitWidth() != 1)
+      return failure();
+
+    auto dstType = this->typeConverter.convertType(op.getResult().getType());
+    Location loc = op.getLoc();
+    Value zero = rewriter.create<ConstantIntOp>(loc, 0, dstType);
+    Value one = rewriter.create<ConstantIntOp>(loc, 1, dstType);
+    rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+        op, dstType, operands.front(), one, zero);
+    return success();
+  }
+};
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
 class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
@@ -455,9 +477,12 @@ public:
   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1);
+    auto srcType = operands.front().getType();
+    if (srcType.isSignlessInteger() && srcType.getIntOrFloatBitWidth() == 1)
+      return failure();
     auto dstType =
         this->typeConverter.convertType(operation.getResult().getType());
-    if (dstType == operands.front().getType()) {
+    if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(operation, operands.front());
@@ -1012,7 +1037,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
       CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
       ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
-      TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
+      ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
index 3fe24d0..0948832 100644 (file)
@@ -564,6 +564,15 @@ func @zexti2(%arg0 : i32) -> i64 {
   return %0 : i64
 }
 
+// CHECK-LABEL: @zexti3
+func @zexti3(%arg0 : i1) -> i32 {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, i32
+  %0 = std.zexti %arg0 : i1 to i32
+  return %0 : i32
+}
+
 // CHECK-LABEL: @trunci1
 func @trunci1(%arg0 : i64) -> i16 {
   // CHECK: spv.SConvert %{{.*}} : i64 to i16