[mlir][llvm] Fuse MD_access_group & MD_loop import
authorChristian Ulmann <christian.ulmann@nextsilicon.com>
Wed, 8 Feb 2023 13:47:29 +0000 (14:47 +0100)
committerChristian Ulmann <christian.ulmann@nextsilicon.com>
Thu, 9 Feb 2023 13:43:02 +0000 (14:43 +0100)
This commit moves the importing logic of access group metadata into the
loop annotation importer. These two metadata imports can be grouped
because access groups are only used in combination with
`llvm.loop.parallel_accesses`.

As a nice side effect, this commit decouples the LoopAnnotationImporter
from the ModuleImport class.

Differential Revision: https://reviews.llvm.org/D143577

mlir/include/mlir/Target/LLVMIR/ModuleImport.h
mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/test/Target/LLVMIR/Import/import-failure.ll

index 23b1fbc29dd724a56fdfb72e8a2c0f08c9c149ae..3265c323724102d43cf4ce6d0c4c48b0ad9bacae 100644 (file)
@@ -302,9 +302,6 @@ private:
   /// to the LLVMIR dialect TBAA operations corresponding to these
   /// nodes.
   DenseMap<const llvm::MDNode *, SymbolRefAttr> tbaaMapping;
-  /// Mapping between original LLVM access group metadata nodes and the symbol
-  /// references pointing to the imported MLIR access group operations.
-  DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
   /// The stateful type translator (contains named structs).
   LLVM::TypeFromLLVMIRTranslator typeTranslator;
   /// Stateful debug information importer.
index a3218e13307e4ccec49f2c729b3b60086ea950dc..a3cbf2bcd47c12edd51ba26472d7eeccf0231f1f 100644 (file)
@@ -16,11 +16,9 @@ using namespace mlir::LLVM::detail;
 namespace {
 /// Helper class that keeps the state of one metadata to attribute conversion.
 struct LoopMetadataConversion {
-  LoopMetadataConversion(const llvm::MDNode *node, ModuleImport &moduleImport,
-                         Location loc,
+  LoopMetadataConversion(const llvm::MDNode *node, Location loc,
                          LoopAnnotationImporter &loopAnnotationImporter)
-      : node(node), moduleImport(moduleImport), loc(loc),
-        loopAnnotationImporter(loopAnnotationImporter),
+      : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
         ctx(loc->getContext()){};
   /// Converts this structs loop metadata node into a LoopAnnotationAttr.
   LoopAnnotationAttr convert();
@@ -55,7 +53,6 @@ struct LoopMetadataConversion {
 
   llvm::StringMap<const llvm::MDNode *> propertyMap;
   const llvm::MDNode *node;
-  ModuleImport &moduleImport;
   Location loc;
   LoopAnnotationImporter &loopAnnotationImporter;
   MLIRContext *ctx;
@@ -233,7 +230,7 @@ LoopMetadataConversion::lookupFollowupNode(StringRef name) {
   if (*node == nullptr)
     return LoopAnnotationAttr(nullptr);
 
-  return loopAnnotationImporter.translate(*node, loc);
+  return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
 }
 
 static bool isEmptyOrNull(const Attribute attr) { return !attr; }
@@ -360,7 +357,7 @@ LoopMetadataConversion::convertParallelAccesses() {
   SmallVector<SymbolRefAttr> refs;
   for (llvm::MDNode *node : *nodes) {
     FailureOr<SmallVector<SymbolRefAttr>> accessGroups =
-        moduleImport.lookupAccessGroupAttrs(node);
+        loopAnnotationImporter.lookupAccessGroupAttrs(node);
     if (failed(accessGroups))
       return emitWarning(loc) << "could not lookup access group";
     llvm::append_range(refs, *accessGroups);
@@ -398,8 +395,9 @@ LoopAnnotationAttr LoopMetadataConversion::convert() {
       parallelAccesses);
 }
 
-LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
-                                                     Location loc) {
+LoopAnnotationAttr
+LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
+                                                Location loc) {
   if (!node)
     return {};
 
@@ -409,9 +407,60 @@ LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
   if (it != loopMetadataMapping.end())
     return it->getSecond();
 
-  LoopAnnotationAttr attr =
-      LoopMetadataConversion(node, moduleImport, loc, *this).convert();
+  LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();
 
   mapLoopMetadata(node, attr);
   return attr;
 }
+
+LogicalResult LoopAnnotationImporter::translateAccessGroup(
+    const llvm::MDNode *node, Location loc, MetadataOp metadataOp) {
+  SmallVector<const llvm::MDNode *> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(node);
+  for (const llvm::MDOperand &operand : node->operands()) {
+    auto *childNode = dyn_cast<llvm::MDNode>(operand);
+    if (!childNode)
+      return emitWarning(loc)
+             << "expected access group operands to be metadata nodes";
+    accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
+  }
+
+  // Convert all entries of the access group list to access group operations.
+  for (const llvm::MDNode *accessGroup : accessGroups) {
+    if (accessGroupMapping.count(accessGroup))
+      continue;
+    // Verify the access group node is distinct and empty.
+    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
+      return emitWarning(loc)
+             << "expected an access group node to be empty and distinct";
+
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToEnd(&metadataOp.getBody().back());
+    auto groupOp = builder.create<AccessGroupMetadataOp>(
+        loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str());
+    // Add a mapping from the access group node to the symbol reference pointing
+    // to the newly created operation.
+    accessGroupMapping[accessGroup] = SymbolRefAttr::get(
+        builder.getContext(), metadataOp.getSymName(),
+        FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
+  }
+  return success();
+}
+
+FailureOr<SmallVector<SymbolRefAttr>>
+LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
+  // An access group node is either a single access group or an access group
+  // list.
+  SmallVector<SymbolRefAttr> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(accessGroupMapping.lookup(node));
+  for (const llvm::MDOperand &operand : node->operands()) {
+    auto *node = cast<llvm::MDNode>(operand.get());
+    accessGroups.push_back(accessGroupMapping.lookup(node));
+  }
+  // Exit if one of the access group node lookups failed.
+  if (llvm::is_contained(accessGroups, nullptr))
+    return failure();
+  return accessGroups;
+}
index bd6f5ef350e64d6d0aaa3f93ebe74d5a55f525c4..5d69a63a21502cdc0986e7a5f8c2ba84d3728ab9 100644 (file)
@@ -21,13 +21,28 @@ namespace mlir {
 namespace LLVM {
 namespace detail {
 
-/// A helper class that converts a `llvm.loop` metadata node into a
-/// corresponding LoopAnnotationAttr.
+/// A helper class that converts llvm.loop metadata nodes into corresponding
+/// LoopAnnotationAttrs and llvm.access.group nodes into
+/// AccessGroupMetadataOps.
 class LoopAnnotationImporter {
 public:
-  explicit LoopAnnotationImporter(ModuleImport &moduleImport)
-      : moduleImport(moduleImport) {}
-  LoopAnnotationAttr translate(const llvm::MDNode *node, Location loc);
+  explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {}
+  LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node,
+                                             Location loc);
+
+  /// Converts all LLVM access groups starting from node to MLIR access group
+  /// operations mested in the region of metadataOp. It stores a mapping from
+  /// every nested access group nod to the symbol pointing to the translated
+  /// operation. Returns success if all conversions succeed and failure
+  /// otherwise.
+  LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc,
+                                     MetadataOp metadataOp);
+
+  /// Returns the symbol references pointing to the access group operations that
+  /// map to the access group nodes starting from the access group metadata
+  /// node. Returns failure, if any of the symbol references cannot be found.
+  FailureOr<SmallVector<SymbolRefAttr>>
+  lookupAccessGroupAttrs(const llvm::MDNode *node) const;
 
 private:
   /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
@@ -42,8 +57,11 @@ private:
            "attempting to map loop options that was already mapped");
   }
 
-  ModuleImport &moduleImport;
+  OpBuilder &builder;
   DenseMap<const llvm::MDNode *, LoopAnnotationAttr> loopMetadataMapping;
+  /// Mapping between original LLVM access group metadata nodes and the symbol
+  /// references pointing to the imported MLIR access group operations.
+  DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
 };
 
 } // namespace detail
index a5142f96fe0a394ffb4b27d8805c956d6de3c75d..992345686e7164734974a9211542dac6bbbee085 100644 (file)
@@ -255,7 +255,8 @@ ModuleImport::ModuleImport(ModuleOp mlirModule,
       iface(mlirModule->getContext()),
       typeTranslator(*mlirModule->getContext()),
       debugImporter(std::make_unique<DebugImporter>(mlirModule)),
-      loopAnnotationImporter(std::make_unique<LoopAnnotationImporter>(*this)) {
+      loopAnnotationImporter(
+          std::make_unique<LoopAnnotationImporter>(builder)) {
   builder.setInsertionPointToStart(mlirModule.getBody());
 }
 
@@ -512,35 +513,11 @@ LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
 
 LogicalResult
 ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
-  // An access group node is either access group or an access group list. Start
-  // by collecting all access groups to translate.
-  SmallVector<const llvm::MDNode *> accessGroups;
-  if (!node->getNumOperands())
-    accessGroups.push_back(node);
-  for (const llvm::MDOperand &operand : node->operands())
-    accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
-
-  // Convert all entries of the access group list to access group operations.
-  for (const llvm::MDNode *accessGroup : accessGroups) {
-    if (accessGroupMapping.count(accessGroup))
-      continue;
-    // Verify the access group node is distinct and empty.
-    Location loc = mlirModule.getLoc();
-    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
-      return emitError(loc) << "unsupported access group node: "
-                            << diagMD(accessGroup, llvmModule.get());
-
-    MetadataOp metadataOp = getGlobalMetadataOp();
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToEnd(&metadataOp.getBody().back());
-    auto groupOp = builder.create<AccessGroupMetadataOp>(
-        loc, (Twine("group_") + Twine(accessGroupMapping.size())).str());
-    // Add a mapping from the access group node to the symbol reference pointing
-    // to the newly created operation.
-    accessGroupMapping[accessGroup] = SymbolRefAttr::get(
-        builder.getContext(), metadataOp.getSymName(),
-        FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
-  }
+  Location loc = mlirModule.getLoc();
+  if (failed(loopAnnotationImporter->translateAccessGroup(
+          node, loc, getGlobalMetadataOp())))
+    return emitError(loc) << "unsupported access group node: "
+                          << diagMD(node, llvmModule.get());
   return success();
 }
 
@@ -1587,25 +1564,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
 
 FailureOr<SmallVector<SymbolRefAttr>>
 ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
-  // An access group node is either a single access group or an access group
-  // list.
-  SmallVector<SymbolRefAttr> accessGroups;
-  if (!node->getNumOperands())
-    accessGroups.push_back(accessGroupMapping.lookup(node));
-  for (const llvm::MDOperand &operand : node->operands()) {
-    auto *node = cast<llvm::MDNode>(operand.get());
-    accessGroups.push_back(accessGroupMapping.lookup(node));
-  }
-  // Exit if one of the access group node lookups failed.
-  if (llvm::is_contained(accessGroups, nullptr))
-    return failure();
-  return accessGroups;
+  return loopAnnotationImporter->lookupAccessGroupAttrs(node);
 }
 
 LoopAnnotationAttr
 ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
                                           Location loc) const {
-  return loopAnnotationImporter->translate(node, loc);
+  return loopAnnotationImporter->translateLoopAnnotation(node, loc);
 }
 
 OwningOpRef<ModuleOp>
index ea35b921648167ef1c4b9edd6282f82cc219941a..12a605a72d344131d545d4a29c5c0993d28ff3b8 100644 (file)
@@ -241,7 +241,8 @@ define dso_local void @tbaa(ptr %0) {
 ; // -----
 
 ; CHECK:      import-failure.ll
-; CHECK-SAME: error: unsupported access group node: !0 = !{}
+; CHECK-SAME: warning: expected an access group node to be empty and distinct
+; CHECK:      error: unsupported access group node: !0 = !{}
 define void @access_group(ptr %arg1) {
   %1 = load i32, ptr %arg1, !llvm.access.group !0
   ret void
@@ -252,7 +253,8 @@ define void @access_group(ptr %arg1) {
 ; // -----
 
 ; CHECK:      import-failure.ll
-; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"}
+; CHECK-SAME: warning: expected an access group node to be empty and distinct
+; CHECK:      error: unsupported access group node: !0 = !{!1}
 define void @access_group(ptr %arg1) {
   %1 = load i32, ptr %arg1, !llvm.access.group !0
   ret void
@@ -263,6 +265,18 @@ define void @access_group(ptr %arg1) {
 
 ; // -----
 
+; CHECK:      import-failure.ll
+; CHECK-SAME: warning: expected access group operands to be metadata nodes
+; CHECK:      error: unsupported access group node: !0 = !{i1 false}
+define void @access_group(ptr %arg1) {
+  %1 = load i32, ptr %arg1, !llvm.access.group !0
+  ret void
+}
+
+!0 = !{i1 false}
+
+; // -----
+
 ; CHECK:      import-failure.ll
 ; CHECK-SAME: warning: expected all loop properties to be either debug locations or metadata nodes
 ; CHECK:      import-failure.ll