//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
return success();
}
};
-
-/// Slice elements from `values` into `outValues`. `counts` represents the
-/// numbers of elements to stride in the original values for each dimension.
-/// The output values can be used to construct a DenseElementsAttr.
-template <typename IterTy, typename ElemTy>
-static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides,
- llvm::SmallVectorImpl<ElemTy> *outValues) {
- assert(offsets.size() == sizes.size());
- assert(offsets.size() == strides.size());
- if (offsets.empty())
- return;
-
- int64_t offset = offsets.front();
- int64_t size = sizes.front();
- int64_t stride = strides.front();
- if (offsets.size() == 1) {
- for (int64_t i = 0; i < size; ++i, offset += stride)
- outValues->push_back(*(values + offset));
-
- return;
- }
-
- for (int64_t i = 0; i < size; ++i, offset += stride) {
- auto begin = values + offset * counts.front();
- sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
- offsets.drop_front(), sizes.drop_front(),
- strides.drop_front(), outValues);
- }
-}
-
-/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
-/// operation might introduce more constant data; Users can control their
-/// heuristics by the control function.
-class ConstantOpExtractSliceFolder final
- : public OpRewritePattern<ExtractSliceOp> {
-public:
- using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
- ConstantOpExtractSliceFolder(MLIRContext *context,
- ControlConstantExtractSliceFusionFn controlFn)
- : OpRewritePattern<ExtractSliceOp>(context),
- controlFn(std::move(controlFn)) {}
-
- LogicalResult matchAndRewrite(ExtractSliceOp op,
- PatternRewriter &rewriter) const override {
- DenseElementsAttr attr;
- if (!matchPattern(op.source(), m_Constant(&attr)))
- return failure();
-
- // A constant splat is handled by fold().
- if (attr.isSplat())
- return failure();
-
- // Dynamic result shape is not supported.
- auto sourceType = op.source().getType().cast<ShapedType>();
- auto resultType = op.result().getType().cast<ShapedType>();
- if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
- return failure();
-
- // Customized control over the folding.
- if (!controlFn(op))
- return failure();
-
- int64_t count = sourceType.getNumElements();
- if (count == 0)
- return failure();
-
- // Check if there are any dynamic parts, which are not supported.
- auto offsets = extractFromI64ArrayAttr(op.static_offsets());
- if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
- return failure();
- auto sizes = extractFromI64ArrayAttr(op.static_sizes());
- if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
- return failure();
- auto strides = extractFromI64ArrayAttr(op.static_strides());
- if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
- return failure();
-
- // Compute the stride for each dimension.
- SmallVector<int64_t> counts;
- ArrayRef<int64_t> shape = sourceType.getShape();
- counts.reserve(shape.size());
- for (int64_t v : shape) {
- count = count / v;
- counts.push_back(count);
- }
-
- // New attribute constructed by the sliced values.
- DenseElementsAttr newAttr;
-
- if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
- SmallVector<APInt> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
- SmallVector<APFloat> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- }
-
- if (newAttr) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
- return success();
- }
-
- return failure();
- }
-
-private:
- /// This additionally controls whether the fold happens or not. Users can
- /// impose their heuristics in the function.
- ControlConstantExtractSliceFusionFn controlFn;
-};
-
} // namespace
-void mlir::tensor::populateFoldConstantExtractSlicePatterns(
- RewritePatternSet &patterns,
- const ControlConstantExtractSliceFusionFn &controlFn) {
- patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
-}
-
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
return this->source();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
-
return OpFoldResult();
}
+++ /dev/null
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
-
-// CHECK-LABEL: func @slice_constant
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
-// CHECK: return %[[CONST]] : tensor<1x1xf32>
-func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
-{
- %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
- %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
- return %slice : tensor<1x1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
-// CHECK: return %[[CONST]] : tensor<2x2xf32>
-func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
- %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
- %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
- return %slice : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_constant_3x4_offsets
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
-// CHECK: return %[[CONST]] : tensor<2x2xf32>
-func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
-{
- %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
- %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
- return %slice : tensor<2x2xf32>
-}
-
*this, "test-split-padding-patterns",
llvm::cl::desc("Test patterns to split tensor.pad ops"),
llvm::cl::init(false)};
-
- Option<bool> testFoldConstantExtractSlice{
- *this, "test-fold-constant-extract-slice",
- llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
- llvm::cl::init(false)};
};
} // namespace
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
- RewritePatternSet patterns(funcOp.getContext());
- tensor::ControlConstantExtractSliceFusionFn controlFn =
- [](tensor::ExtractSliceOp op) {
- if (!op.source().hasOneUse())
- return false;
-
- auto resultType = op.result().getType().cast<ShapedType>();
- constexpr int64_t kConstantFoldingMaxNumElements = 1024;
- if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
- return false;
-
- return true;
- };
-
- tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
void TestTensorTransforms::runOnOperation() {
FuncOp func = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(func);
- if (testFoldConstantExtractSlice)
- applyFoldConstantExtractSlicePatterns(func);
}
namespace mlir {