Add VectorOps.StridedSliceOp
authorNicolas Vasilache <ntv@google.com>
Tue, 19 Nov 2019 20:22:00 +0000 (12:22 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 19 Nov 2019 20:22:34 +0000 (12:22 -0800)
The `vector.strided_slice` takes an n-D vector, k-D `offsets` integer array attribute, a
k-D `sizes` integer array attribute, a k-D `strides` integer array attribute and extracts
the n-D subvector at the proper offset.

Returns an n-D vector where the first k-D dimensions match the `sizes` attribute.
The returned subvector contains the elements starting at offset `offsets` and ending at
`offsets + sizes`.

Example:
```
  %1 = vector.strided_slice %0
      {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
    vector<4x8x16xf32> // returns a vector<2x4x16xf32>
```

This op will be useful for progressive lowering within the VectorOp dialect.

PiperOrigin-RevId: 281352749

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index 9399cc8..12c5612 100644 (file)
@@ -76,6 +76,48 @@ def VectorExtractElementOp :
   }];
 }
 
+def VectorStridedSliceOp :
+  Vector_Op<"strided_slice", [NoSideEffect,
+     PredOpTrait<"operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
+               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Results<(outs AnyVector)> {
+  let summary = "strided_slice operation";
+  let description = [{
+    Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes`
+    integer array attribute, a k-D `strides` integer array attribute and
+    extracts the n-D subvector at the proper offset.
+
+    At the moment strides must contain only 1s.
+
+    Returns an n-D vector where the first k-D dimensions match the `sizes`
+    attribute. The returned subvector contains the elements starting at offset
+    `offsets` and ending at `offsets + sizes`.
+
+    Examples:
+    ```
+      %1 = vector.strided_slice %0
+          {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
+        vector<4x8x16xf32> to vector<2x4x16xf32>
+    ```
+
+    // TODO(Evolve to a range form syntax):
+    %1 = vector.strided_slice %0[0:2:1][2:4:1]
+      vector<4x8x16xf32> to vector<2x4x16xf32>
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &result, Value *source, " #
+    "ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
+    "ArrayRef<int64_t> strides">];
+  let extraClassDeclaration = [{
+    static StringRef getOffsetsAttrName() { return "offsets"; }
+    static StringRef getSizesAttrName() { return "sizes"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+    VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
+  }];
+}
+
 def VectorOuterProductOp :
   Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
index c1244e2..ed0ed43 100644 (file)
@@ -92,7 +92,7 @@ static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
       static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
     return parser.emitError(
         attributeLoc,
-        "expected position attribute of rank smaller than vector");
+        "expected position attribute of rank smaller than vector rank");
 
   Type resType = inferExtractOpResultType(vectorType, positionAttr);
   result.attributes = attrs;
@@ -106,7 +106,7 @@ static LogicalResult verify(VectorExtractElementOp op) {
     return op.emitOpError("expected non-empty position attribute");
   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
     return op.emitOpError(
-        "expected position attribute of rank smaller than vector");
+        "expected position attribute of rank smaller than vector rank");
   for (auto en : llvm::enumerate(positionAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 ||
@@ -120,6 +120,180 @@ static LogicalResult verify(VectorExtractElementOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// VectorStridedSliceOp
+//===----------------------------------------------------------------------===//
+
+static Type inferVectorExtractRangeOpResultType(VectorType vectorType,
+                                                ArrayAttr offsets,
+                                                ArrayAttr sizes,
+                                                ArrayAttr strides) {
+  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
+  SmallVector<int64_t, 4> shape;
+  shape.reserve(vectorType.getRank());
+  unsigned idx = 0;
+  for (unsigned e = offsets.size(); idx < e; ++idx)
+    shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
+  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
+    shape.push_back(vectorType.getShape()[idx]);
+
+  return VectorType::get(shape, vectorType.getElementType());
+}
+
+void VectorStridedSliceOp::build(Builder *builder, OperationState &result,
+                                 Value *source, ArrayRef<int64_t> offsets,
+                                 ArrayRef<int64_t> sizes,
+                                 ArrayRef<int64_t> strides) {
+  result.addOperands(source);
+  auto offsetsAttr = builder->getI64ArrayAttr(offsets);
+  auto sizesAttr = builder->getI64ArrayAttr(sizes);
+  auto stridesAttr = builder->getI64ArrayAttr(strides);
+  result.addTypes(
+      inferVectorExtractRangeOpResultType(source->getType().cast<VectorType>(),
+                                          offsetsAttr, sizesAttr, stridesAttr));
+  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
+  result.addAttribute(getSizesAttrName(), sizesAttr);
+  result.addAttribute(getStridesAttrName(), stridesAttr);
+}
+
+static void print(OpAsmPrinter &p, VectorStridedSliceOp op) {
+  p << op.getOperationName() << " " << *op.vector();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
+}
+
+static ParseResult parseVectorStridedSliceOp(OpAsmParser &parser,
+                                             OperationState &result) {
+  llvm::SMLoc attributeLoc, typeLoc;
+  OpAsmParser::OperandType vector;
+  VectorType vectorType, resultVectorType;
+  return failure(parser.parseOperand(vector) ||
+                 parser.getCurrentLocation(&attributeLoc) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.getCurrentLocation(&typeLoc) ||
+                 parser.parseColonType(vectorType) ||
+                 parser.parseKeywordType("to", resultVectorType) ||
+                 parser.resolveOperand(vector, vectorType, result.operands) ||
+                 parser.addTypeToList(resultVectorType, result.types));
+}
+
+// TODO(ntv) Should be moved to Tablegen Confined attributes.
+static bool isIntegerArrayAttrSmallerThanShape(VectorStridedSliceOp op,
+                                               ArrayAttr arrayAttr,
+                                               ShapedType shape,
+                                               StringRef attrName) {
+  if (arrayAttr.size() > static_cast<unsigned>(shape.getRank())) {
+    op.emitOpError("expected ")
+        << attrName << " attribute of rank smaller than vector rank";
+    return false;
+  }
+  return true;
+}
+
+// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
+// interval. If `halfOpen` is true then the admissible interval is [min, max).
+// Otherwise, the admissible interval is [min, max].
+static bool isIntegerArrayAttrConfinedToRange(VectorStridedSliceOp op,
+                                              ArrayAttr arrayAttr, int64_t min,
+                                              int64_t max, StringRef attrName,
+                                              bool halfOpen = true) {
+  for (auto attr : arrayAttr) {
+    auto val = attr.cast<IntegerAttr>().getInt();
+    auto upper = max;
+    if (!halfOpen)
+      upper += 1;
+    if (val < min || val >= upper) {
+      op.emitOpError("expected ")
+          << attrName << " to be confined to [" << min << ", " << upper << ")";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
+// interval. If `halfOpen` is true then the admissible interval is [min, max).
+// Otherwise, the admissible interval is [min, max].
+static bool
+isIntegerArrayAttrConfinedToShape(VectorStridedSliceOp op, ArrayAttr arrayAttr,
+                                  ShapedType shape, StringRef attrName,
+                                  bool halfOpen = true, int64_t min = 0) {
+  assert(arrayAttr.size() <= static_cast<unsigned>(shape.getRank()));
+  for (auto it : llvm::zip(arrayAttr, shape.getShape())) {
+    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
+    auto max = std::get<1>(it);
+    if (!halfOpen)
+      max += 1;
+    if (val < min || val >= max) {
+      op.emitOpError("expected ")
+          << attrName << " to be confined to [" << min << ", " << max << ")";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns true if all integers in `arrayAttr` are in the interval [min, max}.
+// interval. If `halfOpen` is true then the admissible interval is [min, max).
+// Otherwise, the admissible interval is [min, max].
+static bool isSumOfIntegerArrayAttrConfinedToShape(
+    VectorStridedSliceOp op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
+    ShapedType shape, StringRef attrName1, StringRef attrName2,
+    bool halfOpen = true, int64_t min = 1) {
+  assert(arrayAttr1.size() <= static_cast<unsigned>(shape.getRank()));
+  assert(arrayAttr2.size() <= static_cast<unsigned>(shape.getRank()));
+  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape.getShape())) {
+    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
+    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
+    auto max = std::get<2>(it);
+    if (!halfOpen)
+      max += 1;
+    if (val1 + val2 < 0 || val1 + val2 >= max) {
+      op.emitOpError("expected sum(")
+          << attrName1 << ", " << attrName2 << ") to be confined to [" << min
+          << ", " << max << ")";
+      return false;
+    }
+  }
+  return true;
+}
+
+static LogicalResult verify(VectorStridedSliceOp op) {
+  auto type = op.getVectorType();
+  auto offsets = op.offsets();
+  auto sizes = op.sizes();
+  auto strides = op.strides();
+  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
+    op.emitOpError(
+        "expected offsets, sizes and strides attributes of same size");
+    return failure();
+  }
+
+  auto offName = VectorStridedSliceOp::getOffsetsAttrName();
+  auto sizesName = VectorStridedSliceOp::getSizesAttrName();
+  auto stridesName = VectorStridedSliceOp::getStridesAttrName();
+  if (!isIntegerArrayAttrSmallerThanShape(op, offsets, type, offName) ||
+      !isIntegerArrayAttrSmallerThanShape(op, sizes, type, sizesName) ||
+      !isIntegerArrayAttrSmallerThanShape(op, strides, type, stridesName) ||
+      !isIntegerArrayAttrConfinedToShape(op, offsets, type, offName) ||
+      !isIntegerArrayAttrConfinedToShape(op, sizes, type, sizesName,
+                                         /*halfOpen=*/false, /*min=*/1) ||
+      !isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+                                         /*halfOpen=*/false) ||
+      !isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, type, offName,
+                                              sizesName, /*halfOpen=*/false))
+    return failure();
+
+  auto resultType = inferVectorExtractRangeOpResultType(
+      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
+  if (op.getResult()->getType() != resultType) {
+    op.emitOpError("expected result type to be ") << resultType;
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // VectorOuterProductOp
 //===----------------------------------------------------------------------===//
 
index 2db4cf5..d8ebb98 100644 (file)
@@ -231,6 +231,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   // expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}}
   vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0 + 1)} : vector<128xf32>, memref<?x?xf32>
 }
+
 // -----
 
 func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
@@ -239,3 +240,66 @@ func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
   // expected-error@+1 {{requires a permutation_map that is a permutation (found one dim used more than once)}}
   vector.transfer_write %cst, %arg0[%c3, %c3, %c3] {permutation_map = (d0, d1, d2)->(d0, d0)} : vector<3x7xf32>, memref<?x?x?xf32>
 }
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
+  %1 = vector.strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
+  %1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
+  %1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected offsets to be confined to [0, 4)}}
+  %1 = vector.strided_slice %arg0 {offsets = [100], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected sizes to be confined to [1, 5)}}
+  %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
+  %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
+  %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected sum(offsets, sizes) to be confined to [1, 5)}}
+  %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
+}
+
+// -----
+
+func @strided_slice(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
+  %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
+}
index 77d40f5..7ad46c2 100644 (file)
@@ -41,3 +41,10 @@ func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8
   %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
   return %1 : vector<4x8xf32>
 }
+
+// CHECK-LABEL: strided_slice
+func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
+  //      CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
+  %1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
+  return %1: vector<2x2x16xf32>
+}