[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp.
authorOkwan Kwon <okkwon@gmail.com>
Tue, 22 Feb 2022 21:50:14 +0000 (21:50 +0000)
committerOkwan Kwon <okkwon@gmail.com>
Mon, 28 Feb 2022 17:47:29 +0000 (17:47 +0000)
Fold ExtractSliceOp when the source is a constant.

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

index e6267e9..a4f7875 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -20,6 +21,19 @@ namespace tensor {
 void populateSplitPaddingPatterns(RewritePatternSet &patterns,
                                   PatternBenefit baseBenefit = 1);
 
+/// Function to control the folding of constant and extract slice
+using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
+
+/// Patterns to fold the extract slice op with its constant operand
+void populateFoldConstantExtractSlicePatterns(
+    RewritePatternSet &patterns,
+    const ControlConstantExtractSliceFusionFn &controlFn =
+        [](ExtractSliceOp op) {
+          // Disable by default because the folding can generate a large
+          // constant tensor, which would affect the compile time and storage.
+          return false;
+        });
+
 } // namespace tensor
 } // namespace mlir
 
index 70aa7b5..d6a4bb4 100644 (file)
@@ -6,17 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1158,8 +1155,134 @@ public:
     return success();
   }
 };
+
+/// Slice elements from `values` into `outValues`. `counts` represents the
+/// numbers of elements to stride in the original values for each dimension.
+/// The output values can be used to construct a DenseElementsAttr.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  if (offsets.empty())
+    return;
+
+  int64_t offset = offsets.front();
+  int64_t size = sizes.front();
+  int64_t stride = strides.front();
+  if (offsets.size() == 1) {
+    for (int64_t i = 0; i < size; ++i, offset += stride)
+      outValues->push_back(*(values + offset));
+
+    return;
+  }
+
+  for (int64_t i = 0; i < size; ++i, offset += stride) {
+    auto begin = values + offset * counts.front();
+    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+                                  offsets.drop_front(), sizes.drop_front(),
+                                  strides.drop_front(), outValues);
+  }
+}
+
+/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
+/// operation might introduce more constant data; Users can control their
+/// heuristics by the control function.
+class ConstantOpExtractSliceFolder final
+    : public OpRewritePattern<ExtractSliceOp> {
+public:
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  ConstantOpExtractSliceFolder(MLIRContext *context,
+                               ControlConstantExtractSliceFusionFn controlFn)
+      : OpRewritePattern<ExtractSliceOp>(context),
+        controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(op.source(), m_Constant(&attr)))
+      return failure();
+
+    // A constant splat is handled by fold().
+    if (attr.isSplat())
+      return failure();
+
+    // Dynamic result shape is not supported.
+    auto sourceType = op.source().getType().cast<ShapedType>();
+    auto resultType = op.result().getType().cast<ShapedType>();
+    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
+      return failure();
+
+    // Customized control over the folding.
+    if (!controlFn(op))
+      return failure();
+
+    int64_t count = sourceType.getNumElements();
+    if (count == 0)
+      return failure();
+
+    // Check if there are any dynamic parts, which are not supported.
+    auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+    auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+      return failure();
+    auto strides = extractFromI64ArrayAttr(op.static_strides());
+    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+
+    // Compute the stride for each dimension.
+    SmallVector<int64_t> counts;
+    ArrayRef<int64_t> shape = sourceType.getShape();
+    counts.reserve(shape.size());
+    for (int64_t v : shape) {
+      count = count / v;
+      counts.push_back(count);
+    }
+
+    // New attribute constructed by the sliced values.
+    DenseElementsAttr newAttr;
+
+    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+      SmallVector<APInt> outValues;
+      outValues.reserve(sourceType.getNumElements());
+      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+      SmallVector<APFloat> outValues;
+      outValues.reserve(sourceType.getNumElements());
+      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    }
+
+    if (newAttr) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  /// This additionally controls whether the fold happens or not. Users can
+  /// impose their heuristics in the function.
+  ControlConstantExtractSliceFusionFn controlFn;
+};
+
 } // namespace
 
+void mlir::tensor::populateFoldConstantExtractSlicePatterns(
+    RewritePatternSet &patterns,
+    const ControlConstantExtractSliceFusionFn &controlFn) {
+  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
+}
+
 /// Return the canonical type of the result of an extract_slice op.
 struct SliceReturnTypeCanonicalizer {
   RankedTensorType operator()(ExtractSliceOp op,
@@ -1238,6 +1361,7 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
new file mode 100644 (file)
index 0000000..03c6195
--- /dev/null
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
+
+// CHECK-LABEL: func @slice_constant
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+//       CHECK:   return %[[CONST]] :  tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+  %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+  return %slice : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice_constant_3x4
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
+//       CHECK:   return %[[CONST]] :  tensor<2x2xf32>
+func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+  return %slice : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice_constant_3x4_offsets
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
+//       CHECK:   return %[[CONST]] :  tensor<2x2xf32>
+func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+  %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+  return %slice : tensor<2x2xf32>
+}
+
index c720ca1..4d947ef 100644 (file)
@@ -41,6 +41,11 @@ struct TestTensorTransforms
       *this, "test-split-padding-patterns",
       llvm::cl::desc("Test patterns to split tensor.pad ops"),
       llvm::cl::init(false)};
+
+  Option<bool> testFoldConstantExtractSlice{
+      *this, "test-fold-constant-extract-slice",
+      llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -50,10 +55,31 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  tensor::ControlConstantExtractSliceFusionFn controlFn =
+      [](tensor::ExtractSliceOp op) {
+        if (!op.source().hasOneUse())
+          return false;
+
+        auto resultType = op.result().getType().cast<ShapedType>();
+        constexpr int64_t kConstantFoldingMaxNumElements = 1024;
+        if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
+          return false;
+
+        return true;
+      };
+
+  tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 void TestTensorTransforms::runOnOperation() {
   FuncOp func = getOperation();
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(func);
+  if (testFoldConstantExtractSlice)
+    applyFoldConstantExtractSlicePatterns(func);
 }
 
 namespace mlir {