[mlir][VectorOps] Implement strided_slice conversion
authorNicolas Vasilache <ntv@google.com>
Thu, 9 Jan 2020 07:58:21 +0000 (02:58 -0500)
committerNicolas Vasilache <ntv@google.com>
Thu, 9 Jan 2020 08:03:51 +0000 (03:03 -0500)
Summary:
This diff implements the progressive lowering of strided_slice to either:
  1. extractelement + insertelement for the 1-D case
  2. extract + optional strided_slice + insert for the n-D case.

This combines properly with the other conversion patterns to lower all the way to LLVM.

Appropriate tests are added.

Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante

Reviewed By: andydavis1

Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72310

mlir/include/mlir/IR/Attributes.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

index b839858..64b8063 100644 (file)
@@ -215,6 +215,25 @@ public:
   static bool kindof(unsigned kind) {
     return kind == StandardAttributes::Array;
   }
+
+private:
+  /// Class for underlying value iterator support.
+  template <typename AttrTy>
+  class attr_value_iterator final
+      : public llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)> {
+  public:
+    explicit attr_value_iterator(iterator it)
+        : llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)>(
+              it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
+    AttrTy operator*() { return (*this->I).template cast<AttrTy>(); }
+  };
+
+public:
+  template <typename AttrTy>
+  llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
+    return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
+                            attr_value_iterator<AttrTy>(end()));
+  }
 };
 
 //===----------------------------------------------------------------------===//
index b48930c..7035c2e 100644 (file)
@@ -6,10 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -31,6 +32,7 @@
 #include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
+using namespace mlir::vector;
 
 template <typename T>
 static LLVM::LLVMType getPtrToElementType(T containerType,
@@ -723,15 +725,108 @@ private:
   }
 };
 
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 0,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  auto range = arrayAttr.getAsRange<IntegerAttr>();
+  SmallVector<int64_t, 4> res;
+  res.reserve(arrayAttr.size() - dropFront - dropBack);
+  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+       it != eit; ++it)
+    res.push_back((*it).getValue().getSExtValue());
+  return res;
+}
+
+/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
+/// of `vector`.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+                        int64_t offset) {
+  auto vectorType = vector.getType().cast<VectorType>();
+  if (vectorType.getRank() > 1)
+    return rewriter.create<ExtractOp>(loc, vector, offset);
+  return rewriter.create<vector::ExtractElementOp>(
+      loc, vectorType.getElementType(), vector,
+      rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
+/// of `vector`.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+                       Value into, int64_t offset) {
+  auto vectorType = into.getType().cast<VectorType>();
+  if (vectorType.getRank() > 1)
+    return rewriter.create<InsertOp>(loc, from, into, offset);
+  return rewriter.create<vector::InsertElementOp>(
+      loc, vectorType, from, into,
+      rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Progressive lowering of StridedSliceOp to either:
+///   1. extractelement + insertelement for the 1-D case
+///   2. extract + optional strided_slice + insert for the n-D case.
+class VectorStridedSliceOpRewritePattern
+    : public OpRewritePattern<StridedSliceOp> {
+public:
+  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(StridedSliceOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto dstType = op.getResult().getType().cast<VectorType>();
+
+    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    auto elemType = dstType.getElementType();
+    assert(elemType.isIntOrIndexOrFloat());
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      Value extracted = extractOne(rewriter, loc, op.vector(), off);
+      if (op.offsets().getValue().size() > 1) {
+        StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+            loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
+            getI64SubArray(op.sizes(), /* dropFront=*/1),
+            getI64SubArray(op.strides(), /* dropFront=*/1));
+        // Call matchAndRewrite recursively from within the pattern. This
+        // circumvents the current limitation that a given pattern cannot
+        // be called multiple times by the PatternRewrite infrastructure (to
+        // avoid infinite recursion, but in this case, infinite recursion
+        // cannot happen because the rank is strictly decreasing).
+        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+        // a potential function that must decrease and allow the same pattern
+        // multiple times.
+        auto success = matchAndRewrite(stridedSliceOp, rewriter);
+        (void)success;
+        assert(success && "Unexpected failure");
+        extracted = stridedSliceOp;
+      }
+      res = insertOne(rewriter, loc, extracted, res, idx);
+    }
+    rewriter.replaceOp(op, {res});
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                   VectorExtractElementOpConversion, VectorExtractOpConversion,
                   VectorInsertElementOpConversion, VectorInsertOpConversion,
                   VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(converter.getDialect()->getContext(),
-                                           converter);
+                  VectorPrintOpConversion>(ctx, converter);
 }
 
 namespace {
index 1725a0b..3a00121 100644 (file)
@@ -423,3 +423,64 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
 //       CHECK:    llvm.call @print_close() : () -> ()
 //       CHECK:    llvm.call @print_close() : () -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
+
+
+func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
+// CHECK-LABEL: llvm.func @strided_slice(
+  
+  %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+
+  %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
+//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
+//       CHECK:    llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">
+
+  %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @2
+//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]">
+//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+//       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]">
+//
+// Subvector vector<8xf32> @3
+//       CHECK:    llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]">
+//       CHECK:    llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//       CHECK:    llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(2 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+//       CHECK:    llvm.mlir.constant(0 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK:    llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+//       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">
+
+  return
+}
+
+