From 7b7df8e85eec445389e4b07915f16aa18332719d Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 20 Apr 2021 07:34:32 -0700 Subject: [PATCH] [mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-V std.xor ops on bool are lowered to spv.LogicalNotEqual. For Boolean values, xor and not-equal are the same thing. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D100817 --- .../Conversion/StandardToSPIRV/StandardToSPIRV.cpp | 29 +++++++++++++++++++++- .../StandardToSPIRV/std-ops-to-spirv.mlir | 4 +++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 0196a21..2a6e7f2 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -663,6 +663,17 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector +/// of i1. +class BoolXOrOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(XOrOp xorOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, return success(); } +LogicalResult +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + assert(operands.size() == 2); + + if (!isBoolScalarOrVector(operands.front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(xorOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(xorOp, dstType, + operands); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - SignedRemIOpPattern, XOrOpPattern, + SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 0148a07..fe76948 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) { %0 = and %arg0, %arg1 : i1 // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : i1 + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : i1 return } @@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { %0 = and %arg0, %arg1 : vector<4xi1> // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : vector<4xi1> + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : vector<4xi1> return } -- 2.7.4