[mlir][LLVM] Make the nested type restriction on complex constants less aggressive
authorBenjamin Kramer <benny.kra@googlemail.com>
Wed, 11 May 2022 13:12:21 +0000 (15:12 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Thu, 12 May 2022 09:47:01 +0000 (11:47 +0200)
Complex nested in other types is perfectly fine, just nested structs
aren't supported. Instead of checking whether there's nesting just check
whether the struct we're dealing with is a complex number.

Differential Revision: https://reviews.llvm.org/D125381

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/llvmir-invalid.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 1750c0c..f2b5066 100644 (file)
@@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
 /// report it to `loc` and return nullptr.
 llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
                                 Location loc,
-                                const ModuleTranslation &moduleTranslation,
-                                bool isTopLevel = true);
+                                const ModuleTranslation &moduleTranslation);
 
 /// Creates a call to an LLVM IR intrinsic function with the given arguments.
 llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
index 127e7e1..b0e231d 100644 (file)
@@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
       (type.isa<VectorType>() || hasVectorElementType)) {
     llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
         innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
-        moduleTranslation, /*isTopLevel=*/false);
+        moduleTranslation);
     llvm::Constant *splatVector =
         llvm::ConstantDataVector::getSplat(0, splatValue);
     SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
@@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 /// report it to `loc` and return nullptr.
 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *llvmType, Attribute attr, Location loc,
-    const ModuleTranslation &moduleTranslation, bool isTopLevel) {
+    const ModuleTranslation &moduleTranslation) {
   if (!attr)
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
-    if (!isTopLevel) {
-      emitError(loc, "nested struct types are not supported in constants");
+    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+    if (!arrayAttr || arrayAttr.size() != 2) {
+      emitError(loc, "expected struct type to be a complex number");
       return nullptr;
     }
-    auto arrayAttr = attr.cast<ArrayAttr>();
     llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
-                                           moduleTranslation, false);
+    llvm::Constant *real =
+        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
     if (!real)
       return nullptr;
-    llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
-                                           moduleTranslation, false);
+    llvm::Constant *imag =
+        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
     if (!imag)
       return nullptr;
     return llvm::ConstantStruct::get(structType, {real, imag});
@@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         elementType,
         elementTypeSequential ? splatAttr
                               : splatAttr.getSplatValue<Attribute>(),
-        loc, moduleTranslation, false);
+        loc, moduleTranslation);
     if (!child)
       return nullptr;
     if (llvmType->isVectorTy())
@@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *innermostType = getInnermostElementType(llvmType);
     for (auto n : elementsAttr.getValues<Attribute>()) {
       constants.push_back(
-          getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
+          getLLVMConstant(innermostType, n, loc, moduleTranslation));
       if (!constants.back())
         return nullptr;
     }
index fdbbf9e..ba23c87 100644 (file)
@@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
 
 // -----
 
-llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @+1 {{nested struct types are not supported in constants}}
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+  // expected-error @+1 {{expected struct type to be a complex number}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
 }
 
 // -----
 
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+  // expected-error @+1 {{expected struct type to be a complex number}}
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+// -----
+
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
   // expected-error @+1 {{FloatAttr does not match expected type of the constant}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
index cd14641..b4a2dbc 100644 (file)
@@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
   llvm.return %1 : !llvm.struct<(i32, i32)>
 }
 
+llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
+  llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+}
+
+llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
+  %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
+  // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
+  llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+}
+
 llvm.func @noreach() {
 // CHECK:    unreachable
   llvm.unreachable