From: MaheshRavishankar Date: Wed, 29 Apr 2020 16:57:31 +0000 (-0700) Subject: [mlir][StandardToSPIRV] Handle conversion of cmpi operation with i1 X-Git-Tag: 2020.06-alpha~50^2~583 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1c12a95d9c52de8980ea9979350c7eabc1b9fd01;p=platform%2Fupstream%2Fllvm.git [mlir][StandardToSPIRV] Handle conversion of cmpi operation with i1 type operands. The instructions used to convert std.cmpi cannot have i1 types according to SPIR-V specification. A different set of operations are specified in the SPIR-V spec for comparing boolean types. Enhance the StandardToSPIRV lowering to target these instructions when operands to std.cmpi operation are of i1 type. Differential Revision: https://reviews.llvm.org/D79049 --- diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index b53128a33018..12b22cacdee2 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -184,6 +184,16 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts integer compare operation on i1 type opearnds to SPIR-V ops. +class BoolCmpIOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public SPIRVOpLowering { public: @@ -453,11 +463,43 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, // CmpIOp //===----------------------------------------------------------------------===// +LogicalResult +BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + CmpIOpOperandAdaptor cmpIOpOperands(operands); + + Type operandType = cmpIOp.lhs().getType(); + if (!operandType.isa() || + operandType.cast().getWidth() != 1) + return failure(); + + switch (cmpIOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ + cmpIOpOperands.lhs(), \ + cmpIOpOperands.rhs()); \ + return success(); + + DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); + DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); + +#undef DISPATCH + default:; + } + return failure(); +} + LogicalResult CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); + Type operandType = cmpIOp.lhs().getType(); + if (operandType.isa() && + operandType.cast().getWidth() == 1) + return failure(); + switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ @@ -599,9 +641,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, UnaryAndBinaryOpPattern, BitwiseOpPattern, BitwiseOpPattern, - ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern, - CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern, - StoreOpPattern, TypeCastingOpPattern, + BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, + CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern, + SelectOpPattern, StoreOpPattern, + TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, XOrOpPattern>( context, typeConverter); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 6abdde44e3e5..e7ad95a1a173 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -285,6 +285,15 @@ func @cmpi(%arg0 : i32, %arg1 : i32) { return } +// CHECK-LABEL: @boolcmpi +func @boolcmpi(%arg0 : i1, %arg1 : i1) { + // CHECK: spv.LogicalEqual + %0 = cmpi "eq", %arg0, %arg1 : i1 + // CHECK: spv.LogicalNotEqual + %1 = cmpi "ne", %arg0, %arg1 : i1 + return +} + } // end module // -----