[mlir][spirv] Add folder for LogicalNotEqual
authorThomas Raoux <thomasraoux@google.com>
Fri, 6 Jan 2023 23:03:12 +0000 (23:03 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 6 Jan 2023 23:13:57 +0000 (23:13 +0000)
Add a folder for LogicalNotEqual when rhs is false. This pattern shows
up after lowering to SPIRV.

Differential Revision: https://reviews.llvm.org/D141163

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

index 0abe523..93c9d75 100644 (file)
@@ -723,6 +723,7 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
+  let hasFolder = true;
 }
 
 // -----
index bca91e5..e7d212b 100644 (file)
@@ -252,6 +252,23 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// spirv.LogicalNotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 &&
+         "spirv.LogicalNotEqual should take two operands");
+
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    // x && false = x
+    if (!rhs.value())
+      return getOperand1();
+  }
+
+  return Attribute();
+}
+
+//===----------------------------------------------------------------------===//
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
 
index 518ad2e..f543ed4 100644 (file)
@@ -470,6 +470,22 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
   spirv.ReturnValue %3 : vector<3xi1>
 }
 
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalNotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_not_equal_false
+// CHECK-SAME: %[[ARG:.+]]: vector<4xi1>
+func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
+  %cst = spirv.Constant dense<false> : vector<4xi1>
+  // CHECK: spirv.ReturnValue %[[ARG]] : vector<4xi1>
+  %0 = spirv.LogicalNotEqual %arg, %cst : vector<4xi1>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
 // -----
 
 func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {