[mlir][NFC] Move asOpFoldResult helper functions to StaticValueUtils
authorMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 01:28:25 +0000 (10:28 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 01:28:57 +0000 (10:28 +0900)
Differential Revision: https://reviews.llvm.org/D105602

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp

index 3284c02..5838f1d 100644 (file)
@@ -44,6 +44,14 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
 SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val);
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 
index 32efdc2..12fbb8c 100644 (file)
@@ -89,20 +89,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 
-/// Helper function to convert a Value into an OpFoldResult, if the Value is
-/// known to be a constant index value.
-static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
-  return llvm::to_vector<4>(
-      llvm::map_range(values, [](Value v) -> OpFoldResult {
-        APInt intValue;
-        if (v.getType().isa<IndexType>() &&
-            matchPattern(v, m_ConstantInt(&intValue))) {
-          return IntegerAttr::get(v.getType(), intValue.getSExtValue());
-        }
-        return v;
-      }));
-}
-
 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
 /// `Value`s.
 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
index 59a858f..4c3bb41 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -798,14 +799,6 @@ static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
   return builder.create<ConstantIndexOp>(loc, *intVal);
 }
 
-/// Given a value, try to extract a constant index-type integer as an Attribute.
-/// If this fails, return the original value.
-static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
-  if (auto constInt = getConstantIntValue(val))
-    return builder.getIndexAttr(*constInt);
-  return val;
-}
-
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
@@ -895,7 +888,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // ExtractSliceOp length will be zero in that case. (Effectively reading no
     // data from the source.)
     Value newOffset = min(max(sub(offset, low), zero), srcSize);
-    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+    newOffsets.push_back(getAsOpFoldResult(newOffset));
 
     // The original ExtractSliceOp was reading until position `offset + length`.
     // Therefore, the corresponding position within the source tensor is:
@@ -915,7 +908,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // The new ExtractSliceOp length is `endLoc - newOffset`.
     Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
     Value newLength = sub(endLoc, newOffset);
-    newLengths.push_back(asOpFoldResult(rewriter, newLength));
+    newLengths.push_back(getAsOpFoldResult(newLength));
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
index bf7d662..b6cb4b7 100644 (file)
@@ -47,6 +47,22 @@ SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val) {
+  Attribute attr;
+  if (matchPattern(val, m_Constant(&attr)))
+    return attr;
+  return val;
+}
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(
+      llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
+}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.