//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
return success();
}
};
+
+/// Append the elements of a strided slice of `values` to `outValues`.
+///
+/// `offsets`, `sizes` and `strides` describe the slice one dimension at a
+/// time, outermost dimension first, and must all have the same length.
+/// `counts.front()` is the number of elements to advance in `values` for a
+/// unit step along the current (outermost) dimension; it is only consulted
+/// while more than one dimension remains, since the last dimension is indexed
+/// directly. Nothing is appended when no dimensions are given. The collected
+/// values can be used to construct a DenseElementsAttr.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  if (offsets.empty())
+    return;
+
+  const int64_t dimSize = sizes.front();
+  const int64_t dimStride = strides.front();
+  const bool isInnermostDim = offsets.size() == 1;
+
+  // `index` walks the selected positions along the current dimension.
+  int64_t index = offsets.front();
+  for (int64_t n = 0; n < dimSize; ++n) {
+    if (isInnermostDim) {
+      // Last dimension: positions address individual elements.
+      outValues->push_back(*(values + index));
+    } else {
+      // Outer dimension: jump to the start of the selected sub-block and
+      // slice the remaining dimensions inside it.
+      sliceElements<IterTy, ElemTy>(values + index * counts.front(),
+                                    counts.drop_front(), offsets.drop_front(),
+                                    sizes.drop_front(), strides.drop_front(),
+                                    outValues);
+    }
+    index += dimStride;
+  }
+}
+
+/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
+/// operation might introduce more constant data; users can control their
+/// heuristics by the control function.
+///
+/// The pattern only applies when the source is a non-splat dense constant,
+/// both the source and result types have static shapes, the source is
+/// non-empty, and all offsets, sizes and strides of the slice are static.
+/// Only integer and floating-point element attributes are folded.
+class ConstantOpExtractSliceFolder final
+    : public OpRewritePattern<ExtractSliceOp> {
+public:
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  /// Build the pattern with `controlFn` deciding, per op, whether the fold is
+  /// allowed to happen.
+  ConstantOpExtractSliceFolder(MLIRContext *context,
+                               ControlConstantExtractSliceFusionFn controlFn)
+      : OpRewritePattern<ExtractSliceOp>(context),
+        controlFn(std::move(controlFn)) {}
+
+  /// Replace `op` by an arith.constant holding the sliced elements of its
+  /// constant source. Returns failure (leaving the IR untouched) whenever one
+  /// of the preconditions listed on the class does not hold or the control
+  /// function rejects the op.
+  LogicalResult matchAndRewrite(ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(op.source(), m_Constant(&attr)))
+      return failure();
+
+    // A constant splat is handled by fold().
+    if (attr.isSplat())
+      return failure();
+
+    // Dynamic result shape is not supported.
+    auto sourceType = op.source().getType().cast<ShapedType>();
+    auto resultType = op.result().getType().cast<ShapedType>();
+    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
+      return failure();
+
+    // Customized control over the folding. The inherited constructors leave
+    // `controlFn` unset; treat that as "no additional restriction" instead of
+    // invoking an empty callable.
+    // NOTE(review): assumes ControlConstantExtractSliceFusionFn is
+    // contextually convertible to bool (std::function-like) -- its
+    // declaration is not visible here.
+    if (controlFn && !controlFn(op))
+      return failure();
+
+    // An empty source would make the per-dimension division below divide by
+    // zero, so bail out early.
+    int64_t count = sourceType.getNumElements();
+    if (count == 0)
+      return failure();
+
+    // Check if there are any dynamic parts, which are not supported.
+    auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+    auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+      return failure();
+    auto strides = extractFromI64ArrayAttr(op.static_strides());
+    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+
+    // Compute the stride for each dimension: counts[i] is the number of
+    // source elements spanned by a unit step along dimension i, i.e. the
+    // product of the sizes of all dimensions after i.
+    SmallVector<int64_t> counts;
+    ArrayRef<int64_t> shape = sourceType.getShape();
+    counts.reserve(shape.size());
+    for (int64_t v : shape) {
+      count = count / v;
+      counts.push_back(count);
+    }
+
+    // New attribute constructed by the sliced values.
+    DenseElementsAttr newAttr;
+
+    // Exactly one value is collected per result element, so reserve the
+    // result size rather than the (larger or equal) source size.
+    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+      SmallVector<APInt> outValues;
+      outValues.reserve(resultType.getNumElements());
+      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+      SmallVector<APFloat> outValues;
+      outValues.reserve(resultType.getNumElements());
+      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    }
+
+    // Other element kinds leave `newAttr` null and are not folded.
+    if (newAttr) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  /// This additionally controls whether the fold happens or not. Users can
+  /// impose their heuristics in the function. May be unset when the pattern
+  /// is built through an inherited constructor.
+  ControlConstantExtractSliceFusionFn controlFn;
+};
+
} // namespace
+/// Add the ConstantOpExtractSliceFolder pattern to `patterns`, handing it
+/// `controlFn` as the callback consulted before each fold.
+void mlir::tensor::populateFoldConstantExtractSlicePatterns(
+    RewritePatternSet &patterns,
+    const ControlConstantExtractSliceFusionFn &controlFn) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConstantOpExtractSliceFolder>(context, controlFn);
+}
+
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
return this->source();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
+
return OpFoldResult();
}
--- /dev/null
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
+
+// Taking the 1x1 slice at offset [0, 0] of a 2x1 constant must fold into a
+// 1x1 constant holding the first element (10.0). %arg0 is not used.
+// CHECK-LABEL: func @slice_constant
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+// CHECK: return %[[CONST]] : tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+ %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+ %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+ return %slice : tensor<1x1xf32>
+}
+
+// -----
+
+// Taking the unit-stride 2x2 slice at offset [0, 0] of a 3x4 constant must
+// fold into a constant holding the top-left 2x2 block. %arg0 is not used.
+// CHECK-LABEL: func @slice_constant_3x4
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+ %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+ %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+ return %slice : tensor<2x2xf32>
+}
+
+// -----
+
+// Same 3x4 constant as above, but the 2x2 slice starts at offset [1, 1], so
+// the folded constant must hold rows 1-2 / columns 1-2. %arg0 is not used.
+// CHECK-LABEL: func @slice_constant_3x4_offsets
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+ %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+ %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+ return %slice : tensor<2x2xf32>
+}
+
*this, "test-split-padding-patterns",
llvm::cl::desc("Test patterns to split tensor.pad ops"),
llvm::cl::init(false)};
+
+ // Pass option selecting the arith.constant / tensor.extract_slice folding
+ // test; disabled unless requested.
+ Option<bool> testFoldConstantExtractSlice{
+ *this, "test-fold-constant-extract-slice",
+ llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
+ llvm::cl::init(false)};
};
} // namespace
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Greedily apply the constant/extract_slice folding patterns to `funcOp`.
+/// The fold is only permitted when the slice is the single use of its source
+/// and the result has at most 1024 elements.
+static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
+  tensor::ControlConstantExtractSliceFusionFn shouldFold =
+      [](tensor::ExtractSliceOp sliceOp) {
+        constexpr int64_t kConstantFoldingMaxNumElements = 1024;
+
+        // Skip sources that have any number of uses other than exactly one.
+        if (!sliceOp.source().hasOneUse())
+          return false;
+
+        // Cap the size of the constant created by the fold.
+        auto slicedType = sliceOp.result().getType().cast<ShapedType>();
+        return slicedType.getNumElements() <= kConstantFoldingMaxNumElements;
+      };
+
+  RewritePatternSet foldPatterns(funcOp.getContext());
+  tensor::populateFoldConstantExtractSlicePatterns(foldPatterns, shouldFold);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(foldPatterns));
+}
+
+/// Run every pattern set whose pass option is enabled on the operation this
+/// pass is applied to. The options are checked independently, so more than
+/// one set may run in a single invocation.
void TestTensorTransforms::runOnOperation() {
FuncOp func = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(func);
+ if (testFoldConstantExtractSlice)
+ applyFoldConstantExtractSlicePatterns(func);
}
namespace mlir {