From e630a502230f8779bddd214094d28fef61fde866 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Wed, 8 Feb 2023 14:47:29 +0100 Subject: [PATCH] [mlir][llvm] Fuse MD_access_group & MD_loop import This commit moves the importing logic of access group metadata into the loop annotation importer. These two metadata imports can be grouped because access groups are only used in combination with `llvm.loop.parallel_accesses`. As a nice side effect, this commit decouples the LoopAnnotationImporter from the ModuleImport class. Differential Revision: https://reviews.llvm.org/D143577 --- mlir/include/mlir/Target/LLVMIR/ModuleImport.h | 3 - mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp | 71 +++++++++++++++++++---- mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h | 30 ++++++++-- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 53 +++-------------- mlir/test/Target/LLVMIR/Import/import-failure.ll | 18 +++++- 5 files changed, 109 insertions(+), 66 deletions(-) diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 23b1fbc..3265c32 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -302,9 +302,6 @@ private: /// to the LLVMIR dialect TBAA operations corresponding to these /// nodes. DenseMap tbaaMapping; - /// Mapping between original LLVM access group metadata nodes and the symbol - /// references pointing to the imported MLIR access group operations. - DenseMap accessGroupMapping; /// The stateful type translator (contains named structs). LLVM::TypeFromLLVMIRTranslator typeTranslator; /// Stateful debug information importer. diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp index a3218e1..a3cbf2b 100644 --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp @@ -16,11 +16,9 @@ using namespace mlir::LLVM::detail; namespace { /// Helper class that keeps the state of one metadata to attribute conversion. struct LoopMetadataConversion { - LoopMetadataConversion(const llvm::MDNode *node, ModuleImport &moduleImport, - Location loc, + LoopMetadataConversion(const llvm::MDNode *node, Location loc, LoopAnnotationImporter &loopAnnotationImporter) - : node(node), moduleImport(moduleImport), loc(loc), - loopAnnotationImporter(loopAnnotationImporter), + : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), ctx(loc->getContext()){}; /// Converts this structs loop metadata node into a LoopAnnotationAttr. LoopAnnotationAttr convert(); @@ -55,7 +53,6 @@ struct LoopMetadataConversion { llvm::StringMap propertyMap; const llvm::MDNode *node; - ModuleImport &moduleImport; Location loc; LoopAnnotationImporter &loopAnnotationImporter; MLIRContext *ctx; @@ -233,7 +230,7 @@ LoopMetadataConversion::lookupFollowupNode(StringRef name) { if (*node == nullptr) return LoopAnnotationAttr(nullptr); - return loopAnnotationImporter.translate(*node, loc); + return loopAnnotationImporter.translateLoopAnnotation(*node, loc); } static bool isEmptyOrNull(const Attribute attr) { return !attr; } @@ -360,7 +357,7 @@ LoopMetadataConversion::convertParallelAccesses() { SmallVector refs; for (llvm::MDNode *node : *nodes) { FailureOr> accessGroups = - moduleImport.lookupAccessGroupAttrs(node); + loopAnnotationImporter.lookupAccessGroupAttrs(node); if (failed(accessGroups)) return emitWarning(loc) << "could not lookup access group"; llvm::append_range(refs, *accessGroups); @@ -398,8 +395,9 @@ LoopAnnotationAttr LoopMetadataConversion::convert() { parallelAccesses); } -LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node, - Location loc) { +LoopAnnotationAttr +LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, + Location loc) { if (!node) return {}; @@ -409,9 +407,60 @@ LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node, if (it != loopMetadataMapping.end()) return it->getSecond(); - LoopAnnotationAttr attr = - LoopMetadataConversion(node, moduleImport, loc, *this).convert(); + LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); mapLoopMetadata(node, attr); return attr; } + +LogicalResult LoopAnnotationImporter::translateAccessGroup( + const llvm::MDNode *node, Location loc, MetadataOp metadataOp) { + SmallVector accessGroups; + if (!node->getNumOperands()) + accessGroups.push_back(node); + for (const llvm::MDOperand &operand : node->operands()) { + auto *childNode = dyn_cast(operand); + if (!childNode) + return emitWarning(loc) + << "expected access group operands to be metadata nodes"; + accessGroups.push_back(cast(operand.get())); + } + + // Convert all entries of the access group list to access group operations. + for (const llvm::MDNode *accessGroup : accessGroups) { + if (accessGroupMapping.count(accessGroup)) + continue; + // Verify the access group node is distinct and empty. + if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) + return emitWarning(loc) + << "expected an access group node to be empty and distinct"; + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(&metadataOp.getBody().back()); + auto groupOp = builder.create( + loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str()); + // Add a mapping from the access group node to the symbol reference pointing + // to the newly created operation. + accessGroupMapping[accessGroup] = SymbolRefAttr::get( + builder.getContext(), metadataOp.getSymName(), + FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName())); + } + return success(); +} + +FailureOr> +LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { + // An access group node is either a single access group or an access group + // list. + SmallVector accessGroups; + if (!node->getNumOperands()) + accessGroups.push_back(accessGroupMapping.lookup(node)); + for (const llvm::MDOperand &operand : node->operands()) { + auto *node = cast(operand.get()); + accessGroups.push_back(accessGroupMapping.lookup(node)); + } + // Exit if one of the access group node lookups failed. + if (llvm::is_contained(accessGroups, nullptr)) + return failure(); + return accessGroups; +} diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h index bd6f5ef..5d69a63 100644 --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h @@ -21,13 +21,28 @@ namespace mlir { namespace LLVM { namespace detail { -/// A helper class that converts a `llvm.loop` metadata node into a -/// corresponding LoopAnnotationAttr. +/// A helper class that converts llvm.loop metadata nodes into corresponding +/// LoopAnnotationAttrs and llvm.access.group nodes into +/// AccessGroupMetadataOps. class LoopAnnotationImporter { public: - explicit LoopAnnotationImporter(ModuleImport &moduleImport) - : moduleImport(moduleImport) {} - LoopAnnotationAttr translate(const llvm::MDNode *node, Location loc); + explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {} + LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, + Location loc); + + /// Converts all LLVM access groups starting from node to MLIR access group + /// operations mested in the region of metadataOp. It stores a mapping from + /// every nested access group nod to the symbol pointing to the translated + /// operation. Returns success if all conversions succeed and failure + /// otherwise. + LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc, + MetadataOp metadataOp); + + /// Returns the symbol references pointing to the access group operations that + /// map to the access group nodes starting from the access group metadata + /// node. Returns failure, if any of the symbol references cannot be found. + FailureOr> + lookupAccessGroupAttrs(const llvm::MDNode *node) const; private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. @@ -42,8 +57,11 @@ private: "attempting to map loop options that was already mapped"); } - ModuleImport &moduleImport; + OpBuilder &builder; DenseMap loopMetadataMapping; + /// Mapping between original LLVM access group metadata nodes and the symbol + /// references pointing to the imported MLIR access group operations. + DenseMap accessGroupMapping; }; } // namespace detail diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index a5142f9..9923456 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -255,7 +255,8 @@ ModuleImport::ModuleImport(ModuleOp mlirModule, iface(mlirModule->getContext()), typeTranslator(*mlirModule->getContext()), debugImporter(std::make_unique(mlirModule)), - loopAnnotationImporter(std::make_unique(*this)) { + loopAnnotationImporter( + std::make_unique(builder)) { builder.setInsertionPointToStart(mlirModule.getBody()); } @@ -512,35 +513,11 @@ LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) { LogicalResult ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) { - // An access group node is either access group or an access group list. Start - // by collecting all access groups to translate. - SmallVector accessGroups; - if (!node->getNumOperands()) - accessGroups.push_back(node); - for (const llvm::MDOperand &operand : node->operands()) - accessGroups.push_back(cast(operand.get())); - - // Convert all entries of the access group list to access group operations. - for (const llvm::MDNode *accessGroup : accessGroups) { - if (accessGroupMapping.count(accessGroup)) - continue; - // Verify the access group node is distinct and empty. - Location loc = mlirModule.getLoc(); - if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) - return emitError(loc) << "unsupported access group node: " - << diagMD(accessGroup, llvmModule.get()); - - MetadataOp metadataOp = getGlobalMetadataOp(); - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(&metadataOp.getBody().back()); - auto groupOp = builder.create( - loc, (Twine("group_") + Twine(accessGroupMapping.size())).str()); - // Add a mapping from the access group node to the symbol reference pointing - // to the newly created operation. - accessGroupMapping[accessGroup] = SymbolRefAttr::get( - builder.getContext(), metadataOp.getSymName(), - FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName())); - } + Location loc = mlirModule.getLoc(); + if (failed(loopAnnotationImporter->translateAccessGroup( + node, loc, getGlobalMetadataOp()))) + return emitError(loc) << "unsupported access group node: " + << diagMD(node, llvmModule.get()); return success(); } @@ -1587,25 +1564,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb, FailureOr> ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const { - // An access group node is either a single access group or an access group - // list. - SmallVector accessGroups; - if (!node->getNumOperands()) - accessGroups.push_back(accessGroupMapping.lookup(node)); - for (const llvm::MDOperand &operand : node->operands()) { - auto *node = cast(operand.get()); - accessGroups.push_back(accessGroupMapping.lookup(node)); - } - // Exit if one of the access group node lookups failed. - if (llvm::is_contained(accessGroups, nullptr)) - return failure(); - return accessGroups; + return loopAnnotationImporter->lookupAccessGroupAttrs(node); } LoopAnnotationAttr ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node, Location loc) const { - return loopAnnotationImporter->translate(node, loc); + return loopAnnotationImporter->translateLoopAnnotation(node, loc); } OwningOpRef diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index ea35b92..12a605a 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -241,7 +241,8 @@ define dso_local void @tbaa(ptr %0) { ; // ----- ; CHECK: import-failure.ll -; CHECK-SAME: error: unsupported access group node: !0 = !{} +; CHECK-SAME: warning: expected an access group node to be empty and distinct +; CHECK: error: unsupported access group node: !0 = !{} define void @access_group(ptr %arg1) { %1 = load i32, ptr %arg1, !llvm.access.group !0 ret void @@ -252,7 +253,8 @@ define void @access_group(ptr %arg1) { ; // ----- ; CHECK: import-failure.ll -; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"} +; CHECK-SAME: warning: expected an access group node to be empty and distinct +; CHECK: error: unsupported access group node: !0 = !{!1} define void @access_group(ptr %arg1) { %1 = load i32, ptr %arg1, !llvm.access.group !0 ret void @@ -264,6 +266,18 @@ define void @access_group(ptr %arg1) { ; // ----- ; CHECK: import-failure.ll +; CHECK-SAME: warning: expected access group operands to be metadata nodes +; CHECK: error: unsupported access group node: !0 = !{i1 false} +define void @access_group(ptr %arg1) { + %1 = load i32, ptr %arg1, !llvm.access.group !0 + ret void +} + +!0 = !{i1 false} + +; // ----- + +; CHECK: import-failure.ll ; CHECK-SAME: warning: expected all loop properties to be either debug locations or metadata nodes ; CHECK: import-failure.ll ; CHECK-SAME: warning: unhandled metadata: !0 = distinct !{!0, i32 42} -- 2.7.4