From: George Mitenkov Date: Thu, 2 Jul 2020 18:21:35 +0000 (-0400) Subject: [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors X-Git-Tag: llvmorg-12-init~1178 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1cfaaf645528cc2fed79617c8ca80945a1504021;p=platform%2Fupstream%2Fllvm.git [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors This patch introduces conversion pattern for `spv.constant` with scalar and vector types. There is a special case when the constant value is a signed/unsigned integer (vector of integers). Since LLVM dialect does not have signedness semantics, the types had to be converted to signless ints. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D82936 --- diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp index 3cb8342..1ead619 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -31,6 +31,15 @@ using namespace mlir; // Utility functions //===----------------------------------------------------------------------===// +/// Returns true if the given type is a signed integer or vector type. +static bool isSignedIntegerOrVector(Type type) { + if (type.isSignedInteger()) + return true; + if (auto vecType = type.dyn_cast()) + return vecType.getElementType().isSignedInteger(); + return false; +} + /// Returns true if the given type is an unsigned integer or vector type static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) @@ -197,6 +206,52 @@ public: } }; +/// Converts SPIR-V ConstantOp with scalar or vector type. +class ConstantScalarAndVectorPattern + : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = constOp.getType(); + if (!srcType.isa() && !srcType.isIntOrFloat()) + return failure(); + + auto dstType = typeConverter.convertType(srcType); + if (!dstType) + return failure(); + + // SPIR-V constant can be a signed/unsigned integer, which has to be + // casted to signless integer when converting to LLVM dialect. Removing the + // sign bit may have unexpected behaviour. However, it is better to handle + // it case-by-case, given that the purpose of the conversion is not to + // cover all possible corner cases. + if (isSignedIntegerOrVector(srcType) || + isUnsignedIntegerOrVector(srcType)) { + auto *context = rewriter.getContext(); + auto signlessType = IntegerType::get(getBitWidth(srcType), context); + + if (srcType.isa()) { + auto dstElementsAttr = constOp.value().cast(); + rewriter.replaceOpWithNewOp( + constOp, dstType, + dstElementsAttr.mapValues( + signlessType, [&](const APInt &value) { return value; })); + return success(); + } + auto srcAttr = constOp.value().cast(); + auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } + rewriter.replaceOpWithNewOp(constOp, dstType, operands, + constOp.getAttrs()); + return success(); + } +}; + /// Converts SPIR-V operations that have straightforward LLVM equivalent /// into LLVM dialect operations. template @@ -573,6 +628,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns( IComparePattern, IComparePattern, + // Constant op + ConstantScalarAndVectorPattern, + // Function Call op FunctionCallPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir new file mode 100644 index 0000000..b9605e7 --- /dev/null +++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spv.constant +//===----------------------------------------------------------------------===// + +func @bool_constant_scalar() { + // CHECK: {{.*}} = llvm.mlir.constant(true) : !llvm.i1 + %0 = spv.constant true + // CHECK: {{.*}} = llvm.mlir.constant(false) : !llvm.i1 + %1 = spv.constant false + return +} + +func @bool_constant_vector() { + // CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>"> + %0 = constant dense<[true, false]> : vector<2xi1> + // CHECK: {{.*}} = llvm.mlir.constant(dense : vector<3xi1>) : !llvm<"<3 x i1>"> + %1 = constant dense : vector<3xi1> + return +} + +func @integer_constant_scalar() { + // CHECK: {{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8 + %0 = spv.constant 0 : i8 + // CHECK: {{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64 + %1 = spv.constant -5 : si64 + // CHECK: {{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16 + %2 = spv.constant 10 : ui16 + return +} + +func @integer_constant_vector() { + // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>"> + %0 = spv.constant dense<[2, 3]> : vector<2xi32> + // CHECK: {{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>"> + %1 = spv.constant dense<-4> : vector<2xsi32> + // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>"> + %2 = spv.constant dense<[2, 3, 4]> : vector<3xui32> + return +} + +func @float_constant_scalar() { + // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half + %0 = spv.constant 5.000000e+00 : f16 + // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double + %1 = spv.constant 5.000000e+00 : f64 + return +} + +func @float_constant_vector() { + // CHECK: {{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>"> + %0 = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32> + return +}