[MLIR][SPIRVToLLVM] Convert bitwise and logical not
authorGeorge Mitenkov <georgemitenk0v@gmail.com>
Mon, 29 Jun 2020 23:16:36 +0000 (19:16 -0400)
committerLei Zhang <antiagainst@google.com>
Mon, 29 Jun 2020 23:16:50 +0000 (19:16 -0400)
This patch introduces new conversion patterns for bit and logical
negation op: `spv.Not` and `spv.LogicalNot`. They are implemented
by applying xor on the operand and mask with all bits set.

Differential Revision: https://reviews.llvm.org/D82637

mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir

index 6178b53..83be1c5 100644 (file)
@@ -53,6 +53,16 @@ static unsigned getBitWidth(Type type) {
   return elementType.getIntOrFloatBitWidth();
 }
 
+/// Creates `IntegerAttribute` with all bits set for given type.
+IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
+  if (auto vecType = type.dyn_cast<VectorType>()) {
+    auto integerType = vecType.getElementType().cast<IntegerType>();
+    return builder.getIntegerAttr(integerType, -1);
+  }
+  auto integerType = type.cast<IntegerType>();
+  return builder.getIntegerAttr(integerType, -1);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -154,6 +164,35 @@ public:
   }
 };
 
+/// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
+template <typename SPIRVOp>
+class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto srcType = notOp.getType();
+    auto dstType = this->typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = notOp.getLoc();
+    IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
+    auto mask = srcType.template isa<VectorType>()
+                    ? rewriter.create<LLVM::ConstantOp>(
+                          loc, dstType,
+                          SplatElementsAttr::get(
+                              srcType.template cast<VectorType>(), minusOne))
+                    : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+    rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
+                                                      notOp.operand(), mask);
+    return success();
+  }
+};
+
 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
 public:
   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
@@ -346,6 +385,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
+      NotPattern<spirv::NotOp>,
 
       // Cast ops
       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
@@ -386,6 +426,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
+      NotPattern<spirv::LogicalNotOp>,
 
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
index fb276d4..434430d 100644 (file)
@@ -79,3 +79,21 @@ func @bitwise_xor_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
        %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
        return
 }
+
+//===----------------------------------------------------------------------===//
+// spv.Not
+//===----------------------------------------------------------------------===//
+
+func @not__scalar(%arg0: i32) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm.i32
+       %0 = spv.Not %arg0 : i32
+  return
+}
+
+func @not_vector(%arg0: vector<2xi16>) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi16>) : !llvm<"<2 x i16>">
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm<"<2 x i16>">
+       %0 = spv.Not %arg0 : vector<2xi16>
+  return
+}
index a6ff260..e6f2ec2 100644 (file)
@@ -33,6 +33,24 @@ func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+func @logical_not__scalar(%arg0: i1) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(true) : !llvm.i1
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm.i1
+       %0 = spv.LogicalNot %arg0 : i1
+  return
+}
+
+func @logical_not_vector(%arg0: vector<4xi1>) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<4xi1>) : !llvm<"<4 x i1>">
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm<"<4 x i1>">
+       %0 = spv.LogicalNot %arg0 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
 // spv.LogicalAnd
 //===----------------------------------------------------------------------===//