[mlir][LLVMIR] Use insertelement if needed when translating ConstantAggregate
authorMin-Yih Hsu <minyihh@uci.edu>
Thu, 19 May 2022 18:58:26 +0000 (11:58 -0700)
committerMin-Yih Hsu <minyihh@uci.edu>
Wed, 15 Jun 2022 21:33:47 +0000 (14:33 -0700)
When translating from a llvm::ConstantAggregate with vector type, we
should lower to insertelement operations (if needed) rather than using
insertvalue.

Differential Revision: https://reviews.llvm.org/D127534

mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/Target/LLVMIR/Import/constant-aggregate.ll

index 53e0a31..32237bd 100644 (file)
@@ -533,15 +533,28 @@ Value Importer::processConstant(llvm::Constant *c) {
     Type rootType = processType(c->getType());
     if (!rootType)
       return nullptr;
+    bool useInsertValue = rootType.isa<LLVMArrayType, LLVMStructType>();
+    assert((useInsertValue || LLVM::isCompatibleVectorType(rootType)) &&
+           "unrecognized aggregate type");
     Value root = bEntry.create<UndefOp>(unknownLoc, rootType);
     for (unsigned i = 0; i < numElements; ++i) {
       llvm::Constant *element = getElement(i);
       Value elementValue = processConstant(element);
       if (!elementValue)
         return nullptr;
-      ArrayAttr indexAttr = bEntry.getI32ArrayAttr({static_cast<int32_t>(i)});
-      root = bEntry.create<InsertValueOp>(UnknownLoc::get(context), rootType,
-                                          root, elementValue, indexAttr);
+      if (useInsertValue) {
+        ArrayAttr indexAttr = bEntry.getI32ArrayAttr({static_cast<int32_t>(i)});
+        root = bEntry.create<InsertValueOp>(UnknownLoc::get(context), rootType,
+                                            root, elementValue, indexAttr);
+      } else {
+        Attribute indexAttr = bEntry.getI32IntegerAttr(static_cast<int32_t>(i));
+        Value indexValue = bEntry.create<ConstantOp>(
+            unknownLoc, bEntry.getI32Type(), indexAttr);
+        if (!indexValue)
+          return nullptr;
+        root = bEntry.create<InsertElementOp>(
+            UnknownLoc::get(context), rootType, root, elementValue, indexValue);
+      }
     }
     return root;
   }
index 3a8d270..71919a0 100644 (file)
 %NestedAggType = type {%SimpleAggType, %SimpleAggType*}
 @nestedAgg = global %NestedAggType { %SimpleAggType{i32 1, i8 2, i16 3, i32 4}, %SimpleAggType* null }
 
+; CHECK: %[[C0:.+]] = llvm.mlir.null : !llvm.ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>
+; CHECK: %[[C1:.+]] = llvm.mlir.null : !llvm.ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>
+; CHECK: %[[ROOT:.+]] = llvm.mlir.undef : !llvm.vec<2 x ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>>
+; CHECK: %[[P0:.+]] = llvm.mlir.constant(0 : i32) : i32
+; CHECK: %[[CHAIN0:.+]] = llvm.insertelement %[[C1]], %[[ROOT]][%[[P0]] : i32] : !llvm.vec<2 x ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>>
+; CHECK: %[[P1:.+]] = llvm.mlir.constant(1 : i32) : i32
+; CHECK: %[[CHAIN1:.+]] = llvm.insertelement %[[C0]], %[[CHAIN0]][%[[P1]] : i32] : !llvm.vec<2 x ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>>
+; CHECK: llvm.return %[[CHAIN1]] : !llvm.vec<2 x ptr<struct<"SimpleAggType", (i32, i8, i16, i32)>>>
+@vectorAgg = global <2 x %SimpleAggType*> <%SimpleAggType* null, %SimpleAggType* null>