From 493459b6dd28e4cb7414879a507f641b0414f3e4 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 6 Jan 2023 23:03:12 +0000 Subject: [PATCH] [mlir][spirv] Add folder for LogicalNotEqual Add a folder for LogicalNotEqual when rhs is false. This pattern shows up after lowering to SPIRV. Differential Revision: https://reviews.llvm.org/D141163 --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 1 + mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 17 +++++++++++++++++ mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir | 16 ++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 0abe523..93c9d75 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -723,6 +723,7 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual", %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1> ``` }]; + let hasFolder = true; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index bca91e5..e7d212b 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -252,6 +252,23 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// spirv.LogicalNotEqualOp +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && + "spirv.LogicalNotEqual should take two operands"); + + if (Optional rhs = getScalarOrSplatBoolAttr(operands.back())) { + // x && false = x + if (!rhs.value()) + return getOperand1(); + } + + return Attribute(); +} + +//===----------------------------------------------------------------------===// // spirv.LogicalNot //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 518ad2e..f543ed4 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -470,6 +470,22 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector< spirv.ReturnValue %3 : vector<3xi1> } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.LogicalNotEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @convert_logical_not_equal_false +// CHECK-SAME: %[[ARG:.+]]: vector<4xi1> +func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> { + %cst = spirv.Constant dense : vector<4xi1> + // CHECK: spirv.ReturnValue %[[ARG]] : vector<4xi1> + %0 = spirv.LogicalNotEqual %arg, %cst : vector<4xi1> + spirv.ReturnValue %0 : vector<4xi1> +} + // ----- func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { -- 2.7.4