[mlir][Linalg] NFC - Expose getSmallestBoundingIndex as an utility function
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 20 Jan 2021 19:24:49 +0000 (19:24 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 20 Jan 2021 19:53:09 +0000 (19:53 +0000)
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp

index f194209..1a7dc93 100644 (file)
@@ -142,6 +142,12 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
   inVec = auxVec;
 }
 
+/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
+/// is a constant then return a new value set to the smallest such constant.
+/// If `size` comes from a ConstantOp, return the constant.
+/// Otherwise return nullptr.
+IntegerAttr getSmallestBoundingIndex(Value size);
+
 /// Scheme used to distribute loops to processors.
 enum class DistributionMethod {
   /// Cyclic distribution where no assumption is made about the dynamic
index 329cc88..4d314c3 100644 (file)
@@ -44,27 +44,6 @@ using folded_std_view = FoldedValueBuilder<ViewOp>;
 
 #define DEBUG_TYPE "linalg-promotion"
 
-/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
-/// is a constant then return a new value set to the smallest such constant.
-/// Otherwise return size.
-static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
-                                                 Value size) {
-  Optional<int64_t> boundingConst = {};
-  if (auto affineMinOp = size.getDefiningOp<AffineMinOp>()) {
-    for (auto e : affineMinOp.getAffineMap().getResults())
-      if (auto cst = e.dyn_cast<AffineConstantExpr>())
-        boundingConst = boundingConst
-                            ? std::min(boundingConst.getValue(), cst.getValue())
-                            : cst.getValue();
-  } else if (auto constIndexOp = size.getDefiningOp<ConstantOp>()) {
-    if (constIndexOp.getType().isa<IndexType>())
-      boundingConst = constIndexOp.value().cast<IntegerAttr>().getInt();
-  }
-  return boundingConst && *boundingConst >= 0
-             ? b.create<ConstantIndexOp>(loc, *boundingConst)
-             : size;
-}
-
 /// Alloc a new buffer of `size`. If `dynamicBuffers` is true allocate exactly
 /// the size needed, otherwise try to allocate a static bounding box.
 static Value allocBuffer(const LinalgPromotionOptions &options,
@@ -242,7 +221,9 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     auto rangeValue = en.value();
     // Try to extract a tight constant.
     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
-    Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
+    IntegerAttr sizeAttr = getSmallestBoundingIndex(rangeValue.size);
+    Value size =
+        (!sizeAttr) ? rangeValue.size : b.create<ConstantOp>(loc, sizeAttr);
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
     fullSizes.push_back(size);
     partialSizes.push_back(folded_std_dim(folder, subView, en.index()));
index 81bfbc6..32b32be 100644 (file)
@@ -115,6 +115,31 @@ Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
   return invertedMap.compose(viewSizes);
 }
 
+/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
+/// is a constant then return a new value set to the smallest such constant.
+/// Otherwise returngetSmallestBoundingIndex nullptr.
+IntegerAttr getSmallestBoundingIndex(Value size) {
+  Optional<int64_t> boundingConst = {};
+  if (auto affineMinOp = size.getDefiningOp<AffineMinOp>()) {
+    for (auto e : affineMinOp.getAffineMap().getResults())
+      if (auto cst = e.dyn_cast<AffineConstantExpr>())
+        boundingConst = boundingConst
+                            ? std::min(boundingConst.getValue(), cst.getValue())
+                            : cst.getValue();
+  } else if (auto constIndexOp = size.getDefiningOp<ConstantOp>()) {
+    if (constIndexOp.getType().isa<IndexType>())
+      boundingConst = constIndexOp.value().cast<IntegerAttr>().getInt();
+  } else if (auto affineApplyOp = size.getDefiningOp<AffineApplyOp>()) {
+    if (auto cExpr = affineApplyOp.getAffineMap()
+                         .getResult(0)
+                         .dyn_cast<AffineConstantExpr>())
+      boundingConst = cExpr.getValue();
+  }
+  if (boundingConst && *boundingConst >= 0)
+    return Builder(size.getContext()).getIndexAttr(*boundingConst);
+  return nullptr;
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(