void populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Populate `patterns` with a pattern to breaks down 1-D extract_strided_slice
+/// ops into a chain of Extract ops to extract each element from the source, and
+/// then a chain of Insert ops to insert to the target vector.
+///
+/// If `controlFn` is not nullptr, the pattern will only be invoked on ops that
+/// `controlFn` returns true. Otherwise runs on ops.
+void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ RewritePatternSet &patterns,
+ std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::vector;
}
};
+/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
+/// to extract each element from the source, and then a chain of Insert ops
+/// to insert to the target vector.
+class Convert1DExtractStridedSliceIntoExtractInsertChain final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ Convert1DExtractStridedSliceIntoExtractInsertChain(
+ MLIRContext *context,
+ std::function<bool(ExtractStridedSliceOp)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (controlFn && !controlFn(op))
+ return failure();
+
+ // Only handle 1-D cases.
+ if (op.getOffsets().getValue().size() != 1)
+ return failure();
+
+ int64_t offset =
+ op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size =
+ op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t stride =
+ op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+ Location loc = op.getLoc();
+ SmallVector<Value> elements;
+ elements.reserve(size);
+ for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
+ elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(op.getType()));
+ for (int64_t i = 0; i < size; ++i)
+ result = rewriter.create<InsertOp>(loc, elements[i], result, i);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ std::function<bool(ExtractStridedSliceOp)> controlFn;
+};
+
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
}
};
-void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
+void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ RewritePatternSet &patterns,
+ std::function<bool(ExtractStridedSliceOp)> controlFn,
+ PatternBenefit benefit) {
+ patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
+ patterns.getContext(), std::move(controlFn), benefit);
+}
+
/// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+void vector::populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit) {
populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
benefit);
--- /dev/null
+// RUN: mlir-opt -split-input-file -test-vector-extract-strided-slice-lowering %s | FileCheck %s
+
+// CHECK-LABEL: func.func @extract_strided_slice_1D
+// CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @extract_strided_slice_1D(%input: vector<8xf16>) -> vector<4xf16> {
+ %0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+ return %0: vector<4xf16>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+// CHECK: %[[E0:.+]] = vector.extract %[[INPUT]][1] : vector<8xf16>
+// CHECK: %[[E1:.+]] = vector.extract %[[INPUT]][2] : vector<8xf16>
+// CHECK: %[[E2:.+]] = vector.extract %[[INPUT]][3] : vector<8xf16>
+// CHECK: %[[E3:.+]] = vector.extract %[[INPUT]][4] : vector<8xf16>
+// CHECK: %[[I0:.+]] = vector.insert %[[E0]], %[[INIT]] [0] : f16 into vector<4xf16>
+// CHECK: %[[I1:.+]] = vector.insert %[[E1]], %[[I0]] [1] : f16 into vector<4xf16>
+// CHECK: %[[I2:.+]] = vector.insert %[[E2]], %[[I1]] [2] : f16 into vector<4xf16>
+// CHECK: %[[I3:.+]] = vector.insert %[[E3]], %[[I2]] [3] : f16 into vector<4xf16>
+// CHECK: return %[[I3]]
+
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2D
+func.func @extract_strided_slice_2D(%input: vector<1x8xf16>) -> vector<1x4xf16> {
+ // CHECK: vector.extract_strided_slice
+ %0 = vector.extract_strided_slice %input {offsets = [0, 1], sizes = [1, 4], strides = [1, 1]} : vector<1x8xf16> to vector<1x4xf16>
+ return %0: vector<1x4xf16>
+}
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
}
};
+struct TestVectorExtractStridedSliceLowering
+ : public PassWrapper<TestVectorExtractStridedSliceLowering,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorExtractStridedSliceLowering)
+
+ StringRef getArgument() const final {
+ return "test-vector-extract-strided-slice-lowering";
+ }
+ StringRef getDescription() const final {
+ return "Test lowering patterns that converts vector.extract_strided_slice "
+ "into a chain of vector.extract and vector.insert ops";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
PassRegistration<TestVectorScanLowering>();
PassRegistration<TestVectorDistribution>();
+
+ PassRegistration<TestVectorExtractStridedSliceLowering>();
}
} // namespace test
} // namespace mlir