//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
+using namespace mlir::vector;
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
}
};
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0) {
+ assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+ auto range = arrayAttr.getAsRange<IntegerAttr>();
+ SmallVector<int64_t, 4> res;
+ res.reserve(arrayAttr.size() - dropFront - dropBack);
+ for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+ it != eit; ++it)
+ res.push_back((*it).getValue().getSExtValue());
+ return res;
+}
+
+/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
+/// of `vector`.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+ int64_t offset) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<ExtractOp>(loc, vector, offset);
+ return rewriter.create<vector::ExtractElementOp>(
+ loc, vectorType.getElementType(), vector,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
+/// of `vector`.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+ Value into, int64_t offset) {
+ auto vectorType = into.getType().cast<VectorType>();
+ if (vectorType.getRank() > 1)
+ return rewriter.create<InsertOp>(loc, from, into, offset);
+ return rewriter.create<vector::InsertElementOp>(
+ loc, vectorType, from, into,
+ rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Progressive lowering of StridedSliceOp to either:
+/// 1. extractelement + insertelement for the 1-D case
+/// 2. extract + optional strided_slice + insert for the n-D case.
+class VectorStridedSliceOpRewritePattern
+ : public OpRewritePattern<StridedSliceOp> {
+public:
+ using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto dstType = op.getResult().getType().cast<VectorType>();
+
+ assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
+
+ int64_t offset =
+ op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+ int64_t stride =
+ op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+ auto loc = op.getLoc();
+ auto elemType = dstType.getElementType();
+ assert(elemType.isIntOrIndexOrFloat());
+ Value zero = rewriter.create<ConstantOp>(loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+ for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+ off += stride, ++idx) {
+ Value extracted = extractOne(rewriter, loc, op.vector(), off);
+ if (op.offsets().getValue().size() > 1) {
+ StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+ loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
+ getI64SubArray(op.sizes(), /* dropFront=*/1),
+ getI64SubArray(op.strides(), /* dropFront=*/1));
+ // Call matchAndRewrite recursively from within the pattern. This
+ // circumvents the current limitation that a given pattern cannot
+ // be called multiple times by the PatternRewrite infrastructure (to
+ // avoid infinite recursion, but in this case, infinite recursion
+ // cannot happen because the rank is strictly decreasing).
+ // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+ // a potential function that must decrease and allow the same pattern
+ // multiple times.
+ auto success = matchAndRewrite(stridedSliceOp, rewriter);
+ (void)success;
+ assert(success && "Unexpected failure");
+ extracted = stridedSliceOp;
+ }
+ res = insertOne(rewriter, loc, extracted, res, idx);
+ }
+ rewriter.replaceOp(op, {res});
+ return matchSuccess();
+ }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ MLIRContext *ctx = converter.getDialect()->getContext();
+ patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(converter.getDialect()->getContext(),
- converter);
+ VectorPrintOpConversion>(ctx, converter);
}
namespace {
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_newline() : () -> ()
+
+
+func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
+// CHECK-LABEL: llvm.func @strided_slice(
+
+ %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+
+ %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">
+
+ %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @2
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @3
+// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">
+
+ return
+}
+
+