From 27dad99622bb16fc7ba94beda26dded9023bf2cd Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 11 May 2022 15:12:21 +0200 Subject: [PATCH] [mlir][LLVM] Make the nested type restriction on complex constants less aggressive Complex nested in other types is perfectly fine, just nested structs aren't supported. Instead of checking whether there's nesting just check whether the struct we're dealing with is a complex number. Differential Revision: https://reviews.llvm.org/D125381 --- .../include/mlir/Target/LLVMIR/ModuleTranslation.h | 3 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 22 +++++++++++----------- mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 12 ++++++++++-- mlir/test/Target/LLVMIR/llvmir.mlir | 12 ++++++++++++ 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 1750c0c..f2b5066 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -351,8 +351,7 @@ SetVector getTopologicallySortedBlocks(Region ®ion); /// report it to `loc` and return nullptr. llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation, - bool isTopLevel = true); + const ModuleTranslation &moduleTranslation); /// Creates a call to an LLVM IR intrinsic function with the given arguments. llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 127e7e1..b0e231d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, (type.isa() || hasVectorElementType)) { llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( innermostLLVMType, denseElementsAttr.getSplatValue(), loc, - moduleTranslation, /*isTopLevel=*/false); + moduleTranslation); llvm::Constant *splatVector = llvm::ConstantDataVector::getSplat(0, splatValue); SmallVector constants(numAggregates, splatVector); @@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, /// report it to `loc` and return nullptr. llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation, bool isTopLevel) { + const ModuleTranslation &moduleTranslation) { if (!attr) return llvm::UndefValue::get(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { - if (!isTopLevel) { - emitError(loc, "nested struct types are not supported in constants"); + auto arrayAttr = attr.dyn_cast(); + if (!arrayAttr || arrayAttr.size() != 2) { + emitError(loc, "expected struct type to be a complex number"); return nullptr; } - auto arrayAttr = attr.cast(); llvm::Type *elementType = structType->getElementType(0); - llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc, - moduleTranslation, false); + llvm::Constant *real = + getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); if (!real) return nullptr; - llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc, - moduleTranslation, false); + llvm::Constant *imag = + getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); if (!imag) return nullptr; return llvm::ConstantStruct::get(structType, {real, imag}); @@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( elementType, elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), - loc, moduleTranslation, false); + loc, moduleTranslation); if (!child) return nullptr; if (llvmType->isVectorTy()) @@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *innermostType = getInnermostElementType(llvmType); for (auto n : elementsAttr.getValues()) { constants.push_back( - getLLVMConstant(innermostType, n, loc, moduleTranslation, false)); + getLLVMConstant(innermostType, n, loc, moduleTranslation)); if (!constants.back()) return nullptr; } diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index fdbbf9e..ba23c87 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 { // ----- -llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @+1 {{nested struct types are not supported in constants}} +llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { + // expected-error @+1 {{expected struct type to be a complex number}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } // ----- +llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { + // expected-error @+1 {{expected struct type to be a complex number}} + %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> + llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> +} + +// ----- + llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { // expected-error @+1 {{FloatAttr does not match expected type of the constant}} %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index cd14641..b4a2dbcf 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> { llvm.return %1 : !llvm.struct<(i32, i32)> } +llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> { + %1 = llvm.mlir.constant(dense<(0, 1)> : tensor>) : !llvm.array<2 x !llvm.struct<(i32, i32)>> + // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }] + llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>> +} + +llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> { + %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>> + // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]] + llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> +} + llvm.func @noreach() { // CHECK: unreachable llvm.unreachable -- 2.7.4