[mlir][llvm] Adapt loop metadata to match llvm
authorChristian Ulmann <christian.ulmann@nextsilicon.com>
Fri, 10 Feb 2023 13:33:41 +0000 (14:33 +0100)
committerChristian Ulmann <christian.ulmann@nextsilicon.com>
Fri, 10 Feb 2023 13:44:15 +0000 (14:44 +0100)
This commit adds support for the "llvm.loop.isvectorized" metadata and
ensures that the unroll followups match llvm's naming.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D143730

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
mlir/test/Dialect/LLVMIR/loop-metadata.mlir
mlir/test/Target/LLVMIR/Import/metadata-loop.ll
mlir/test/Target/LLVMIR/loop-metadata.mlir

index 66370e7..06783b1 100644 (file)
@@ -95,8 +95,9 @@ def LoopUnrollAttr : LLVM_Attr<"LoopUnroll", "loop_unroll"> {
     OptionalParameter<"IntegerAttr">:$count,
     OptionalParameter<"BoolAttr">:$runtimeDisable,
     OptionalParameter<"BoolAttr">:$full,
-    OptionalParameter<"LoopAnnotationAttr">:$followup,
-    OptionalParameter<"LoopAnnotationAttr">:$followupRemainder
+    OptionalParameter<"LoopAnnotationAttr">:$followupUnrolled,
+    OptionalParameter<"LoopAnnotationAttr">:$followupRemainder,
+    OptionalParameter<"LoopAnnotationAttr">:$followupAll
   );
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -186,6 +187,7 @@ def LoopAnnotationAttr : LLVM_Attr<"LoopAnnotation", "loop_annotation"> {
     OptionalParameter<"LoopDistributeAttr">:$distribute,
     OptionalParameter<"LoopPipelineAttr">:$pipeline,
     OptionalParameter<"BoolAttr">:$mustProgress,
+    OptionalParameter<"BoolAttr">:$isVectorized,
     OptionalArrayRefParameter<"SymbolRefAttr">:$parallelAccesses
   );
 
index a3cbf2b..53e6f9c 100644 (file)
@@ -33,6 +33,7 @@ struct LoopMetadataConversion {
   /// specified name, or failure, if the node is ill-formatted.
   FailureOr<BoolAttr> lookupUnitNode(StringRef name);
   FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false);
+  FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
   FailureOr<IntegerAttr> lookupIntNode(StringRef name);
   FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
   FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
@@ -155,6 +156,27 @@ FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
   return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
 }
 
+FailureOr<BoolAttr>
+LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
+  const llvm::MDNode *property = lookupAndEraseProperty(name);
+  if (!property)
+    return BoolAttr(nullptr);
+
+  auto emitNodeWarning = [&]() {
+    return emitWarning(loc)
+           << "expected metadata node " << name << " to hold an integer value";
+  };
+
+  if (property->getNumOperands() != 2)
+    return emitNodeWarning();
+  llvm::ConstantInt *val =
+      llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
+  if (!val || val->getBitWidth() != 32)
+    return emitNodeWarning();
+
+  return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
+}
+
 FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
   const llvm::MDNode *property = lookupAndEraseProperty(name);
   if (!property)
@@ -287,13 +309,16 @@ FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
   FailureOr<BoolAttr> runtimeDisable =
       lookupUnitNode("llvm.loop.unroll.runtime.disable");
   FailureOr<BoolAttr> full = lookupUnitNode("llvm.loop.unroll.full");
-  FailureOr<LoopAnnotationAttr> followup =
-      lookupFollowupNode("llvm.loop.unroll.followup");
+  FailureOr<LoopAnnotationAttr> followupUnrolled =
+      lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
   FailureOr<LoopAnnotationAttr> followupRemainder =
       lookupFollowupNode("llvm.loop.unroll.followup_remainder");
+  FailureOr<LoopAnnotationAttr> followupAll =
+      lookupFollowupNode("llvm.loop.unroll.followup_all");
 
   return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
-                                         full, followup, followupRemainder);
+                                         full, followupUnrolled,
+                                         followupRemainder, followupAll);
 }
 
 FailureOr<LoopUnrollAndJamAttr>
@@ -379,6 +404,8 @@ LoopAnnotationAttr LoopMetadataConversion::convert() {
   FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
   FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
   FailureOr<BoolAttr> mustProgress = lookupUnitNode("llvm.loop.mustprogress");
+  FailureOr<BoolAttr> isVectorized =
+      lookupIntNodeAsBoolAttr("llvm.loop.isvectorized");
   FailureOr<SmallVector<SymbolRefAttr>> parallelAccesses =
       convertParallelAccesses();
 
@@ -392,7 +419,7 @@ LoopAnnotationAttr LoopMetadataConversion::convert() {
   return createIfNonNull<LoopAnnotationAttr>(
       ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
       unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, mustProgress,
-      parallelAccesses);
+      isVectorized, parallelAccesses);
 }
 
 LoopAnnotationAttr
index b8b6e3c..5864b48 100644 (file)
@@ -29,6 +29,7 @@ struct LoopAnnotationConversion {
   /// Conversion functions for different payload attribute kinds.
   void addUnitNode(StringRef name);
   void addUnitNode(StringRef name, BoolAttr attr);
+  void addI32NodeWithVal(StringRef name, uint32_t val);
   void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
   void convertI32Node(StringRef name, IntegerAttr attr);
   void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
@@ -61,6 +62,14 @@ void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
     addUnitNode(name);
 }
 
+void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
+  llvm::Constant *cstValue = llvm::ConstantInt::get(
+      llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false);
+  metadataNodes.push_back(
+      llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
+                              llvm::ConstantAsMetadata::get(cstValue)}));
+}
+
 void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
                                                bool negated) {
   if (!attr)
@@ -76,12 +85,7 @@ void LoopAnnotationConversion::convertI32Node(StringRef name,
                                               IntegerAttr attr) {
   if (!attr)
     return;
-  uint32_t val = attr.getInt();
-  llvm::Constant *cstValue = llvm::ConstantInt::get(
-      llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false);
-  metadataNodes.push_back(
-      llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
-                              llvm::ConstantAsMetadata::get(cstValue)}));
+  addI32NodeWithVal(name, attr.getInt());
 }
 
 void LoopAnnotationConversion::convertFollowupNode(StringRef name,
@@ -122,9 +126,12 @@ void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
   convertBoolNode("llvm.loop.unroll.runtime.disable",
                   options.getRuntimeDisable());
   addUnitNode("llvm.loop.unroll.full", options.getFull());
-  convertFollowupNode("llvm.loop.unroll.followup", options.getFollowup());
+  convertFollowupNode("llvm.loop.unroll.followup_unrolled",
+                      options.getFollowupUnrolled());
   convertFollowupNode("llvm.loop.unroll.followup_remainder",
                       options.getFollowupRemainder());
+  convertFollowupNode("llvm.loop.unroll.followup_all",
+                      options.getFollowupAll());
 }
 
 void LoopAnnotationConversion::convertLoopOptions(
@@ -177,6 +184,9 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
 
   addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
   addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
+  // "isvectorized" is encoded as an i32 value.
+  if (BoolAttr isVectorized = attr.getIsVectorized())
+    addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue());
 
   if (auto options = attr.getVectorize())
     convertLoopOptions(options);
index edd9592..0d43432 100644 (file)
 // CHECK-DAG: #[[INTERLEAVE:.*]] = #llvm.loop_interleave<count = 32 : i32>
 #interleave = #llvm.loop_interleave<count = 32 : i32>
 
-// CHECK-DAG: #[[UNROLL:.*]] = #llvm.loop_unroll<disable = true, count = 32 : i32, runtimeDisable = true, full = false, followup = #[[FOLLOWUP]], followupRemainder = #[[FOLLOWUP]]>
+// CHECK-DAG: #[[UNROLL:.*]] = #llvm.loop_unroll<disable = true, count = 32 : i32, runtimeDisable = true, full = false, followupUnrolled = #[[FOLLOWUP]], followupRemainder = #[[FOLLOWUP]], followupAll = #[[FOLLOWUP]]>
 #unroll = #llvm.loop_unroll<
   disable = true, count = 32 : i32, runtimeDisable = true, full = false,
-  followup = #followup, followupRemainder = #followup
+  followupUnrolled = #followup, followupRemainder = #followup, followupAll = #followup
 >
 
 // CHECK-DAG: #[[UNROLL_AND_JAM:.*]] = #llvm.loop_unroll_and_jam<disable = false, count = 16 : i32, followupOuter = #[[FOLLOWUP]], followupInner = #[[FOLLOWUP]], followupRemainderOuter = #[[FOLLOWUP]], followupRemainderInner = #[[FOLLOWUP]], followupAll = #[[FOLLOWUP]]>
@@ -44,6 +44,7 @@
 // CHECK-DAG: licm = #[[LICM]]
 // CHECK-DAG: distribute = #[[DISTRIBUTE]]
 // CHECK-DAG: pipeline = #[[PIPELINE]]
+// CHECK-DAG: isVectorized = false
 // CHECK-DAG: parallelAccesses = @metadata::@group1, @metadata::@group2>
 #loopMD = #llvm.loop_annotation<disableNonforced = false,
         mustProgress = true,
@@ -54,6 +55,7 @@
         licm = #licm,
         distribute = #distribute,
         pipeline = #pipeline,
+        isVectorized = false,
         parallelAccesses = @metadata::@group1, @metadata::@group2>
 
 // CHECK: llvm.func @loop_annotation
index 93d29b0..1ddd5e2 100644 (file)
@@ -28,7 +28,7 @@ define void @access_group(ptr %arg1) {
 
 ; // -----
 
-; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation<disableNonforced = true, mustProgress = true>
+; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation<disableNonforced = true, mustProgress = true, isVectorized = true>
 
 ; CHECK-LABEL: @simple
 define void @simple(i64 %n, ptr %A) {
@@ -39,9 +39,10 @@ end:
   ret void
 }
 
-!1 = distinct !{!1, !2, !3}
+!1 = distinct !{!1, !2, !3, !4}
 !2 = !{!"llvm.loop.disable_nonforced"}
 !3 = !{!"llvm.loop.mustprogress"}
+!4 = !{!"llvm.loop.isvectorized", i32 1}
 
 ; // -----
 
@@ -90,7 +91,7 @@ end:
 ; // -----
 
 ; CHECK-DAG: #[[FOLLOWUP:.*]] = #llvm.loop_annotation<disableNonforced = true>
-; CHECK-DAG: #[[UNROLL_ATTR:.*]] = #llvm.loop_unroll<disable = false, count = 16 : i32, runtimeDisable = true, full = true, followup = #[[FOLLOWUP]], followupRemainder = #[[FOLLOWUP]]>
+; CHECK-DAG: #[[UNROLL_ATTR:.*]] = #llvm.loop_unroll<disable = false, count = 16 : i32, runtimeDisable = true, full = true, followupUnrolled = #[[FOLLOWUP]], followupRemainder = #[[FOLLOWUP]], followupAll = #[[FOLLOWUP]]>
 ; CHECK-DAG: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation<unroll = #[[UNROLL_ATTR]]>
 
 ; CHECK-LABEL: @unroll
@@ -102,16 +103,17 @@ end:
   ret void
 }
 
-!1 = distinct !{!1, !2, !3, !4, !5, !6, !7}
+!1 = distinct !{!1, !2, !3, !4, !5, !6, !7, !8}
 !2 = !{!"llvm.loop.unroll.enable"}
 !3 = !{!"llvm.loop.unroll.count", i32 16}
 !4 = !{!"llvm.loop.unroll.runtime.disable"}
 !5 = !{!"llvm.loop.unroll.full"}
-!6 = !{!"llvm.loop.unroll.followup", !8}
-!7 = !{!"llvm.loop.unroll.followup_remainder", !8}
+!6 = !{!"llvm.loop.unroll.followup_unrolled", !9}
+!7 = !{!"llvm.loop.unroll.followup_remainder", !9}
+!8 = !{!"llvm.loop.unroll.followup_all", !9}
 
-!8 = distinct !{!8, !9}
-!9 = !{!"llvm.loop.disable_nonforced"}
+!9 = distinct !{!9, !10}
+!10 = !{!"llvm.loop.disable_nonforced"}
 
 ; // -----
 
index f2eaa1c..9bed3ae 100644 (file)
@@ -26,6 +26,18 @@ llvm.func @mustprogress() {
 
 // -----
 
+// CHECK-LABEL: @isvectorized
+llvm.func @isvectorized() {
+  // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]]
+  llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation<isVectorized = true>}
+^bb1:
+  llvm.return
+}
+
+// CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}}
+// CHECK-DAG: ![[VEC_NODE0:[0-9]+]] = !{!"llvm.loop.isvectorized", i32 1}
+
+// -----
 
 #followup = #llvm.loop_annotation<disableNonforced = true>
 
@@ -73,7 +85,7 @@ llvm.func @unrollOptions() {
   // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]]
   llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation<unroll = <
     disable = true, count = 64 : i32, runtimeDisable = false, full = false,
-    followup = #followup, followupRemainder = #followup>
+    followupUnrolled = #followup, followupRemainder = #followup, followupAll = #followup>
   >}
 ^bb1:
   llvm.return
@@ -81,12 +93,13 @@ llvm.func @unrollOptions() {
 
 // CHECK-DAG: ![[NON_FORCED:[0-9]+]] = !{!"llvm.loop.disable_nonforced"}
 // CHECK-DAG: ![[FOLLOWUP:[0-9]+]] = distinct !{![[FOLLOWUP]], ![[NON_FORCED]]}
-// CHECK-DAG: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}
+// CHECK-DAG: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}}
 // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.disable"}
 // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.count", i32 64}
 // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.runtime.disable", i1 false}
-// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup", ![[FOLLOWUP]]}
+// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_unrolled", ![[FOLLOWUP]]}
 // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_remainder", ![[FOLLOWUP]]}
+// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_all", ![[FOLLOWUP]]}
 
 // -----