Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp."
authorOkwan Kwon <okkwon@gmail.com>
Mon, 28 Feb 2022 19:11:29 +0000 (19:11 +0000)
committerOkwan Kwon <okkwon@gmail.com>
Mon, 28 Feb 2022 19:14:05 +0000 (19:14 +0000)
This reverts commit 3104994104f0c2f274acf5e01eb6cc82e9cca06b.

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir [deleted file]
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

index a4f7875..e6267e9 100644 (file)
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -21,19 +20,6 @@ namespace tensor {
 void populateSplitPaddingPatterns(RewritePatternSet &patterns,
                                   PatternBenefit baseBenefit = 1);
 
-/// Function to control the folding of constant and extract slice
-using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
-
-/// Patterns to fold the extract slice op with its constant operand
-void populateFoldConstantExtractSlicePatterns(
-    RewritePatternSet &patterns,
-    const ControlConstantExtractSliceFusionFn &controlFn =
-        [](ExtractSliceOp op) {
-          // Disable by default because the folding can generate a large
-          // constant tensor, which would affect the compile time and storage.
-          return false;
-        });
-
 } // namespace tensor
 } // namespace mlir
 
index d6a4bb4..70aa7b5 100644 (file)
@@ -6,14 +6,17 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1155,134 +1158,8 @@ public:
     return success();
   }
 };
-
-/// Slice elements from `values` into `outValues`. `counts` represents the
-/// numbers of elements to stride in the original values for each dimension.
-/// The output values can be used to construct a DenseElementsAttr.
-template <typename IterTy, typename ElemTy>
-static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
-                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
-                          ArrayRef<int64_t> strides,
-                          llvm::SmallVectorImpl<ElemTy> *outValues) {
-  assert(offsets.size() == sizes.size());
-  assert(offsets.size() == strides.size());
-  if (offsets.empty())
-    return;
-
-  int64_t offset = offsets.front();
-  int64_t size = sizes.front();
-  int64_t stride = strides.front();
-  if (offsets.size() == 1) {
-    for (int64_t i = 0; i < size; ++i, offset += stride)
-      outValues->push_back(*(values + offset));
-
-    return;
-  }
-
-  for (int64_t i = 0; i < size; ++i, offset += stride) {
-    auto begin = values + offset * counts.front();
-    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
-                                  offsets.drop_front(), sizes.drop_front(),
-                                  strides.drop_front(), outValues);
-  }
-}
-
-/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
-/// operation might introduce more constant data; Users can control their
-/// heuristics by the control function.
-class ConstantOpExtractSliceFolder final
-    : public OpRewritePattern<ExtractSliceOp> {
-public:
-  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
-  ConstantOpExtractSliceFolder(MLIRContext *context,
-                               ControlConstantExtractSliceFusionFn controlFn)
-      : OpRewritePattern<ExtractSliceOp>(context),
-        controlFn(std::move(controlFn)) {}
-
-  LogicalResult matchAndRewrite(ExtractSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    DenseElementsAttr attr;
-    if (!matchPattern(op.source(), m_Constant(&attr)))
-      return failure();
-
-    // A constant splat is handled by fold().
-    if (attr.isSplat())
-      return failure();
-
-    // Dynamic result shape is not supported.
-    auto sourceType = op.source().getType().cast<ShapedType>();
-    auto resultType = op.result().getType().cast<ShapedType>();
-    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
-      return failure();
-
-    // Customized control over the folding.
-    if (!controlFn(op))
-      return failure();
-
-    int64_t count = sourceType.getNumElements();
-    if (count == 0)
-      return failure();
-
-    // Check if there are any dynamic parts, which are not supported.
-    auto offsets = extractFromI64ArrayAttr(op.static_offsets());
-    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
-      return failure();
-    auto sizes = extractFromI64ArrayAttr(op.static_sizes());
-    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
-      return failure();
-    auto strides = extractFromI64ArrayAttr(op.static_strides());
-    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
-      return failure();
-
-    // Compute the stride for each dimension.
-    SmallVector<int64_t> counts;
-    ArrayRef<int64_t> shape = sourceType.getShape();
-    counts.reserve(shape.size());
-    for (int64_t v : shape) {
-      count = count / v;
-      counts.push_back(count);
-    }
-
-    // New attribute constructed by the sliced values.
-    DenseElementsAttr newAttr;
-
-    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
-      SmallVector<APInt> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
-      SmallVector<APFloat> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    }
-
-    if (newAttr) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
-      return success();
-    }
-
-    return failure();
-  }
-
-private:
-  /// This additionally controls whether the fold happens or not. Users can
-  /// impose their heuristics in the function.
-  ControlConstantExtractSliceFusionFn controlFn;
-};
-
 } // namespace
 
-void mlir::tensor::populateFoldConstantExtractSlicePatterns(
-    RewritePatternSet &patterns,
-    const ControlConstantExtractSliceFusionFn &controlFn) {
-  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
-}
-
 /// Return the canonical type of the result of an extract_slice op.
 struct SliceReturnTypeCanonicalizer {
   RankedTensorType operator()(ExtractSliceOp op,
@@ -1361,7 +1238,6 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
-
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
deleted file mode 100644 (file)
index 03c6195..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
-
-// CHECK-LABEL: func @slice_constant
-//   CHECK-NOT:   tensor.extract_slice
-//       CHECK:   %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
-//       CHECK:   return %[[CONST]] :  tensor<1x1xf32>
-func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
-{
-  %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
-  %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
-  return %slice : tensor<1x1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4
-//   CHECK-NOT:   tensor.extract_slice
-//       CHECK:   %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
-//       CHECK:   return %[[CONST]] :  tensor<2x2xf32>
-func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
-  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
-  %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
-  return %slice : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4_offsets
-//   CHECK-NOT:   tensor.extract_slice
-//       CHECK:   %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
-//       CHECK:   return %[[CONST]] :  tensor<2x2xf32>
-func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
-  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
-  %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
-  return %slice : tensor<2x2xf32>
-}
-
index 4d947ef..c720ca1 100644 (file)
@@ -41,11 +41,6 @@ struct TestTensorTransforms
       *this, "test-split-padding-patterns",
       llvm::cl::desc("Test patterns to split tensor.pad ops"),
       llvm::cl::init(false)};
-
-  Option<bool> testFoldConstantExtractSlice{
-      *this, "test-fold-constant-extract-slice",
-      llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -55,31 +50,10 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  tensor::ControlConstantExtractSliceFusionFn controlFn =
-      [](tensor::ExtractSliceOp op) {
-        if (!op.source().hasOneUse())
-          return false;
-
-        auto resultType = op.result().getType().cast<ShapedType>();
-        constexpr int64_t kConstantFoldingMaxNumElements = 1024;
-        if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
-          return false;
-
-        return true;
-      };
-
-  tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 void TestTensorTransforms::runOnOperation() {
   FuncOp func = getOperation();
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(func);
-  if (testFoldConstantExtractSlice)
-    applyFoldConstantExtractSlicePatterns(func);
 }
 
 namespace mlir {