[mlir][StandardToSPIRV] Extend support for lowering cmpi to SPIRV.
authorHanhan Wang <hanchung@google.com>
Mon, 16 Nov 2020 14:50:45 +0000 (06:50 -0800)
committerHanhan Wang <hanchung@google.com>
Mon, 16 Nov 2020 14:51:05 +0000 (06:51 -0800)
The logic of vector on boolean was missed. This patch adds the logic and test on
it.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D91403

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index 8736ad4..ac8b82d 100644 (file)
@@ -767,8 +767,7 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   CmpIOpAdaptor cmpIOpOperands(operands);
 
   Type operandType = cmpIOp.lhs().getType();
-  if (!operandType.isa<IntegerType>() ||
-      operandType.cast<IntegerType>().getWidth() != 1)
+  if (!isBoolScalarOrVector(operandType))
     return failure();
 
   switch (cmpIOp.getPredicate()) {
@@ -794,8 +793,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   CmpIOpAdaptor cmpIOpOperands(operands);
 
   Type operandType = cmpIOp.lhs().getType();
-  if (operandType.isa<IntegerType>() &&
-      operandType.cast<IntegerType>().getWidth() == 1)
+  if (isBoolScalarOrVector(operandType))
     return failure();
 
   switch (cmpIOp.getPredicate()) {
index 9f112bb..10e43ef 100644 (file)
@@ -327,6 +327,15 @@ func @boolcmpi(%arg0 : i1, %arg1 : i1) {
   return
 }
 
+// CHECK-LABEL: @vecboolcmpi
+func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spv.LogicalEqual
+  %0 = cmpi "eq", %arg0, %arg1 : vector<4xi1>
+  // CHECK: spv.LogicalNotEqual
+  %1 = cmpi "ne", %arg0, %arg1 : vector<4xi1>
+  return
+}
+
 } // end module
 
 // -----