Add a vector.InsertStridedSliceOp
authorNicolas Vasilache <ntv@google.com>
Mon, 25 Nov 2019 23:36:45 +0000 (15:36 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 25 Nov 2019 23:37:13 +0000 (15:37 -0800)
This new op is the counterpart of vector.StridedSliceOp and will be used for in the pattern rewrites for vector unrolling.

PiperOrigin-RevId: 282447414

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index 138fa8a..a887a3e 100644 (file)
@@ -207,7 +207,7 @@ def Vector_InsertElementOp :
     ```
       %2 = vector.insertelement %0, %1[3 : i32]:
         vector<8x16xf32> into vector<4x8x16xf32>
-      %5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]: 
+      %5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]:
         f32 into vector<4x8x16xf32>
     ```
   }];
@@ -223,45 +223,47 @@ def Vector_InsertElementOp :
   }];
 }
 
-def Vector_StridedSliceOp :
-  Vector_Op<"strided_slice", [NoSideEffect,
-     PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+def Vector_InsertStridedSliceOp :
+  Vector_Op<"insert_strided_slice", [NoSideEffect,
+    PredOpTrait<"operand #0 and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+    PredOpTrait<"dest operand and result have same type",
+                 TCresIsSameAsOpBase<0, 1>>]>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
+               I64ArrayAttr:$strides)>,
     Results<(outs AnyVector)> {
   let summary = "strided_slice operation";
   let description = [{
-    Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes`
-    integer array attribute, a k-D `strides` integer array attribute and
-    extracts the n-D subvector at the proper offset.
+    Takes a k-D source vector, an n-D destination vector (n >= k), n-D `offsets`
+    integer array attribute, a k-D `strides` integer array attribute and inserts
+    the k-D source vector as a strided subvector at the proper offset into the
+    n-D destination vector.
 
     At the moment strides must contain only 1s.
 
-    Returns an n-D vector where the first k-D dimensions match the `sizes`
-    attribute. The returned subvector contains the elements starting at offset
-    `offsets` and ending at `offsets + sizes`.
+    Returns an n-D vector that is a copy of the n-D destination vector in which
+    the last k-D dimensions contain the k-D source vector elements strided at
+    the proper location as specified by the offsets.
 
     Examples:
     ```
-      %1 = vector.strided_slice %0
-          {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
-        vector<4x8x16xf32> to vector<2x4x16xf32>
+      %2 = vector.insert_strided_slice %0, %1
+          {offsets : [0, 0, 2], strides : [1, 1]}:
+        vector<2x4xf32> into vector<16x4x8xf32>
     ```
-
-    // TODO(Evolve to a range form syntax):
-    %1 = vector.strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
   }];
   let builders = [OpBuilder<
-    "Builder *builder, OperationState &result, Value *source, " #
-    "ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
-    "ArrayRef<int64_t> strides">];
+    "Builder *builder, OperationState &result, Value *source, Value *dest, " #
+    "ArrayRef<int64_t> offsets, ArrayRef<int64_t> strides">];
   let extraClassDeclaration = [{
     static StringRef getOffsetsAttrName() { return "offsets"; }
-    static StringRef getSizesAttrName() { return "sizes"; }
     static StringRef getStridesAttrName() { return "strides"; }
-    VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
+    VectorType getSourceVectorType() {
+      return source()->getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return dest()->getType().cast<VectorType>();
+    }
   }];
 }
 
@@ -304,6 +306,49 @@ def Vector_OuterProductOp :
   }];
 }
 
+def Vector_StridedSliceOp :
+  Vector_Op<"strided_slice", [NoSideEffect,
+    PredOpTrait<"operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
+               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Results<(outs AnyVector)> {
+  let summary = "strided_slice operation";
+  let description = [{
+    Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes`
+    integer array attribute, a k-D `strides` integer array attribute and
+    extracts the n-D subvector at the proper offset.
+
+    At the moment strides must contain only 1s.
+    // TODO(ntv) support non-1 strides.
+
+    Returns an n-D vector where the first k-D dimensions match the `sizes`
+    attribute. The returned subvector contains the elements starting at offset
+    `offsets` and ending at `offsets + sizes`.
+
+    Examples:
+    ```
+      %1 = vector.strided_slice %0
+          {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
+        vector<4x8x16xf32> to vector<2x4x16xf32>
+    ```
+
+    // TODO(ntv) Evolve to a range form syntax similar to:
+    %1 = vector.strided_slice %0[0:2:1][2:4:1]
+      vector<4x8x16xf32> to vector<2x4x16xf32>
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &result, Value *source, " #
+    "ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
+    "ArrayRef<int64_t> strides">];
+  let extraClassDeclaration = [{
+    static StringRef getOffsetsAttrName() { return "offsets"; }
+    static StringRef getSizesAttrName() { return "sizes"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+    VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
+  }];
+}
+
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read">,
     Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
index bda3c3f..789639c 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -295,8 +296,8 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-static Type inferExtractOpResultType(VectorType vectorType,
-                                     ArrayAttr position) {
+static Type inferExtractElementOpResultType(VectorType vectorType,
+                                            ArrayAttr position) {
   if (static_cast<int64_t>(position.size()) == vectorType.getRank())
     return vectorType.getElementType();
   return VectorType::get(vectorType.getShape().drop_front(position.size()),
@@ -307,8 +308,8 @@ void ExtractElementOp::build(Builder *builder, OperationState &result,
                              Value *source, ArrayRef<int32_t> position) {
   result.addOperands(source);
   auto positionAttr = builder->getI32ArrayAttr(position);
-  result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
-                                           positionAttr));
+  result.addTypes(inferExtractElementOpResultType(
+      source->getType().cast<VectorType>(), positionAttr));
   result.addAttribute(getPositionAttrName(), positionAttr);
 }
 
@@ -342,7 +343,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
         attributeLoc,
         "expected position attribute of rank smaller than vector rank");
 
-  Type resType = inferExtractOpResultType(vectorType, positionAttr);
+  Type resType = inferExtractElementOpResultType(vectorType, positionAttr);
   result.attributes = attrs;
   return failure(parser.resolveOperand(vector, type, result.operands) ||
                  parser.addTypeToList(resType, result.types));
@@ -440,175 +441,170 @@ static LogicalResult verify(InsertElementOp op) {
 }
 
 //===----------------------------------------------------------------------===//
-// StridedSliceOp
+// InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
 
-static Type inferExtractRangeOpResultType(VectorType vectorType,
-                                          ArrayAttr offsets, ArrayAttr sizes,
-                                          ArrayAttr strides) {
-  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
-  SmallVector<int64_t, 4> shape;
-  shape.reserve(vectorType.getRank());
-  unsigned idx = 0;
-  for (unsigned e = offsets.size(); idx < e; ++idx)
-    shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
-  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
-    shape.push_back(vectorType.getShape()[idx]);
-
-  return VectorType::get(shape, vectorType.getElementType());
-}
-
-void StridedSliceOp::build(Builder *builder, OperationState &result,
-                           Value *source, ArrayRef<int64_t> offsets,
-                           ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
-  result.addOperands(source);
+void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
+                                 Value *source, Value *dest,
+                                 ArrayRef<int64_t> offsets,
+                                 ArrayRef<int64_t> strides) {
+  result.addOperands({source, dest});
   auto offsetsAttr = builder->getI64ArrayAttr(offsets);
-  auto sizesAttr = builder->getI64ArrayAttr(sizes);
   auto stridesAttr = builder->getI64ArrayAttr(strides);
-  result.addTypes(
-      inferExtractRangeOpResultType(source->getType().cast<VectorType>(),
-                                    offsetsAttr, sizesAttr, stridesAttr));
+  result.addTypes(dest->getType());
   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
-  result.addAttribute(getSizesAttrName(), sizesAttr);
   result.addAttribute(getStridesAttrName(), stridesAttr);
 }
 
-static void print(OpAsmPrinter &p, StridedSliceOp op) {
-  p << op.getOperationName() << " " << *op.vector();
+static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
+  p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
+    << " ";
   p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
+  p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
 }
 
-static ParseResult parseStridedSliceOp(OpAsmParser &parser,
-                                       OperationState &result) {
-  llvm::SMLoc attributeLoc, typeLoc;
-  OpAsmParser::OperandType vector;
-  VectorType vectorType, resultVectorType;
-  return failure(parser.parseOperand(vector) ||
-                 parser.getCurrentLocation(&attributeLoc) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.getCurrentLocation(&typeLoc) ||
-                 parser.parseColonType(vectorType) ||
-                 parser.parseKeywordType("to", resultVectorType) ||
-                 parser.resolveOperand(vector, vectorType, result.operands) ||
-                 parser.addTypeToList(resultVectorType, result.types));
+static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
+                                             OperationState &result) {
+  OpAsmParser::OperandType source, dest;
+  VectorType sourceVectorType, destVectorType;
+  return failure(
+      parser.parseOperand(source) || parser.parseComma() ||
+      parser.parseOperand(dest) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(sourceVectorType) ||
+      parser.parseKeywordType("into", destVectorType) ||
+      parser.resolveOperand(source, sourceVectorType, result.operands) ||
+      parser.resolveOperand(dest, destVectorType, result.operands) ||
+      parser.addTypeToList(destVectorType, result.types));
 }
 
 // TODO(ntv) Should be moved to Tablegen Confined attributes.
-static bool isIntegerArrayAttrSmallerThanShape(StridedSliceOp op,
-                                               ArrayAttr arrayAttr,
-                                               ShapedType shape,
-                                               StringRef attrName) {
-  if (arrayAttr.size() > static_cast<unsigned>(shape.getRank())) {
-    op.emitOpError("expected ")
-        << attrName << " attribute of rank smaller than vector rank";
-    return false;
-  }
-  return true;
+template <typename OpType>
+LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr,
+                                                 ArrayRef<int64_t> shape,
+                                                 StringRef attrName) {
+  if (arrayAttr.size() > shape.size())
+    return op.emitOpError("expected ")
+           << attrName << " attribute of rank smaller than vector rank";
+  return success();
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
 // interval. If `halfOpen` is true then the admissible interval is [min, max).
 // Otherwise, the admissible interval is [min, max].
-static bool isIntegerArrayAttrConfinedToRange(StridedSliceOp op,
-                                              ArrayAttr arrayAttr, int64_t min,
-                                              int64_t max, StringRef attrName,
-                                              bool halfOpen = true) {
+template <typename OpType>
+LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr,
+                                                int64_t min, int64_t max,
+                                                StringRef attrName,
+                                                bool halfOpen = true) {
   for (auto attr : arrayAttr) {
     auto val = attr.cast<IntegerAttr>().getInt();
     auto upper = max;
     if (!halfOpen)
       upper += 1;
-    if (val < min || val >= upper) {
-      op.emitOpError("expected ")
-          << attrName << " to be confined to [" << min << ", " << upper << ")";
-      return false;
-    }
+    if (val < min || val >= upper)
+      return op.emitOpError("expected ") << attrName << " to be confined to ["
+                                         << min << ", " << upper << ")";
   }
-  return true;
+  return success();
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
 // interval. If `halfOpen` is true then the admissible interval is [min, max).
 // Otherwise, the admissible interval is [min, max].
-static bool
-isIntegerArrayAttrConfinedToShape(StridedSliceOp op, ArrayAttr arrayAttr,
-                                  ShapedType shape, StringRef attrName,
+template <typename OpType>
+LogicalResult
+isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
+                                  ArrayRef<int64_t> shape, StringRef attrName,
                                   bool halfOpen = true, int64_t min = 0) {
-  assert(arrayAttr.size() <= static_cast<unsigned>(shape.getRank()));
-  for (auto it : llvm::zip(arrayAttr, shape.getShape())) {
+  assert(arrayAttr.size() <= shape.size());
+  unsigned index = 0;
+  for (auto it : llvm::zip(arrayAttr, shape)) {
     auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
     auto max = std::get<1>(it);
     if (!halfOpen)
       max += 1;
-    if (val < min || val >= max) {
-      op.emitOpError("expected ")
-          << attrName << " to be confined to [" << min << ", " << max << ")";
-      return false;
-    }
+    if (val < min || val >= max)
+      return op.emitOpError("expected ")
+             << attrName << " dimension " << index << " to be confined to ["
+             << min << ", " << max << ")";
+    ++index;
   }
-  return true;
+  return success();
 }
 
 // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
 // interval. If `halfOpen` is true then the admissible interval is [min, max).
 // Otherwise, the admissible interval is [min, max].
-static bool
-isSumOfIntegerArrayAttrConfinedToShape(StridedSliceOp op, ArrayAttr arrayAttr1,
-                                       ArrayAttr arrayAttr2, ShapedType shape,
-                                       StringRef attrName1, StringRef attrName2,
-                                       bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= static_cast<unsigned>(shape.getRank()));
-  assert(arrayAttr2.size() <= static_cast<unsigned>(shape.getRank()));
-  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape.getShape())) {
+template <typename OpType>
+LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
+    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
+    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+    bool halfOpen = true, int64_t min = 1) {
+  assert(arrayAttr1.size() <= shape.size());
+  assert(arrayAttr2.size() <= shape.size());
+  unsigned index = 0;
+  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
     auto max = std::get<2>(it);
     if (!halfOpen)
       max += 1;
-    if (val1 + val2 < 0 || val1 + val2 >= max) {
-      op.emitOpError("expected sum(")
-          << attrName1 << ", " << attrName2 << ") to be confined to [" << min
-          << ", " << max << ")";
-      return false;
-    }
+    if (val1 + val2 < 0 || val1 + val2 >= max)
+      return op.emitOpError("expected sum(")
+             << attrName1 << ", " << attrName2 << ") dimension " << index
+             << " to be confined to [" << min << ", " << max << ")";
+    ++index;
   }
-  return true;
+  return success();
 }
 
-static LogicalResult verify(StridedSliceOp op) {
-  auto type = op.getVectorType();
+static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
+                                  MLIRContext *context) {
+  auto attrs = functional::map(
+      [context](int64_t v) -> Attribute {
+        return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+      },
+      values);
+  return ArrayAttr::get(attrs, context);
+}
+
+static LogicalResult verify(InsertStridedSliceOp op) {
+  auto sourceVectorType = op.getSourceVectorType();
+  auto destVectorType = op.getDestVectorType();
   auto offsets = op.offsets();
-  auto sizes = op.sizes();
   auto strides = op.strides();
-  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
-    op.emitOpError(
-        "expected offsets, sizes and strides attributes of same size");
+  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank())) {
+    op.emitOpError("expected offsets of same size as destination vector rank");
     return failure();
   }
-
-  auto offName = StridedSliceOp::getOffsetsAttrName();
-  auto sizesName = StridedSliceOp::getSizesAttrName();
-  auto stridesName = StridedSliceOp::getStridesAttrName();
-  if (!isIntegerArrayAttrSmallerThanShape(op, offsets, type, offName) ||
-      !isIntegerArrayAttrSmallerThanShape(op, sizes, type, sizesName) ||
-      !isIntegerArrayAttrSmallerThanShape(op, strides, type, stridesName) ||
-      !isIntegerArrayAttrConfinedToShape(op, offsets, type, offName) ||
-      !isIntegerArrayAttrConfinedToShape(op, sizes, type, sizesName,
-                                         /*halfOpen=*/false, /*min=*/1) ||
-      !isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
-                                         /*halfOpen=*/false) ||
-      !isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, type, offName,
-                                              sizesName, /*halfOpen=*/false))
+  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank())) {
+    op.emitOpError("expected strides of same size as source vector rank");
     return failure();
-
-  auto resultType = inferExtractRangeOpResultType(
-      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
-  if (op.getResult()->getType() != resultType) {
-    op.emitOpError("expected result type to be ") << resultType;
+  }
+  if (sourceVectorType.getRank() > destVectorType.getRank()) {
+    op.emitOpError("expected source rank to be smaller than destination rank");
     return failure();
   }
 
+  auto sourceShape = sourceVectorType.getShape();
+  auto destShape = destVectorType.getShape();
+  SmallVector<int64_t, 4> sourceShapeAsDestShape(
+      destShape.size() - sourceShape.size(), 0);
+  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
+  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
+  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
+  if (failed(
+          isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
+      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+                                               /*halfOpen=*/false)) ||
+      failed(isSumOfIntegerArrayAttrConfinedToShape(
+          op, offsets,
+          makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
+          offName, "source vector shape",
+          /*halfOpen=*/false, /*min=*/1)))
+    return failure();
+
   return success();
 }
 
@@ -667,6 +663,104 @@ static LogicalResult verify(OuterProductOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// StridedSliceOp
+//===----------------------------------------------------------------------===//
+
+// Inference works as follows:
+//   1. Add 'sizes' from prefix of dims in 'offsets'.
+//   2. Add sizes from 'vectorType' for remaining dims.
+static Type inferStridedSliceOpResultType(VectorType vectorType,
+                                          ArrayAttr offsets, ArrayAttr sizes,
+                                          ArrayAttr strides) {
+  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
+  SmallVector<int64_t, 4> shape;
+  shape.reserve(vectorType.getRank());
+  unsigned idx = 0;
+  for (unsigned e = offsets.size(); idx < e; ++idx)
+    shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
+  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
+    shape.push_back(vectorType.getShape()[idx]);
+
+  return VectorType::get(shape, vectorType.getElementType());
+}
+
+void StridedSliceOp::build(Builder *builder, OperationState &result,
+                           Value *source, ArrayRef<int64_t> offsets,
+                           ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
+  result.addOperands(source);
+  auto offsetsAttr = builder->getI64ArrayAttr(offsets);
+  auto sizesAttr = builder->getI64ArrayAttr(sizes);
+  auto stridesAttr = builder->getI64ArrayAttr(strides);
+  result.addTypes(
+      inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
+                                    offsetsAttr, sizesAttr, stridesAttr));
+  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
+  result.addAttribute(getSizesAttrName(), sizesAttr);
+  result.addAttribute(getStridesAttrName(), stridesAttr);
+}
+
+static void print(OpAsmPrinter &p, StridedSliceOp op) {
+  p << op.getOperationName() << " " << *op.vector();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
+}
+
+static ParseResult parseStridedSliceOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  llvm::SMLoc attributeLoc, typeLoc;
+  OpAsmParser::OperandType vector;
+  VectorType vectorType, resultVectorType;
+  return failure(parser.parseOperand(vector) ||
+                 parser.getCurrentLocation(&attributeLoc) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.getCurrentLocation(&typeLoc) ||
+                 parser.parseColonType(vectorType) ||
+                 parser.parseKeywordType("to", resultVectorType) ||
+                 parser.resolveOperand(vector, vectorType, result.operands) ||
+                 parser.addTypeToList(resultVectorType, result.types));
+}
+
+static LogicalResult verify(StridedSliceOp op) {
+  auto type = op.getVectorType();
+  auto offsets = op.offsets();
+  auto sizes = op.sizes();
+  auto strides = op.strides();
+  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
+    op.emitOpError(
+        "expected offsets, sizes and strides attributes of same size");
+    return failure();
+  }
+
+  auto shape = type.getShape();
+  auto offName = StridedSliceOp::getOffsetsAttrName();
+  auto sizesName = StridedSliceOp::getSizesAttrName();
+  auto stridesName = StridedSliceOp::getStridesAttrName();
+  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
+      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
+      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
+                                                stridesName)) ||
+      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
+      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
+                                               /*halfOpen=*/false,
+                                               /*min=*/1)) ||
+      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+                                               /*halfOpen=*/false)) ||
+      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
+                                                    offName, sizesName,
+                                                    /*halfOpen=*/false)))
+    return failure();
+
+  auto resultType = inferStridedSliceOpResultType(
+      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
+  if (op.getResult()->getType() != resultType) {
+    op.emitOpError("expected result type to be ") << resultType;
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 template <typename EmitFun>
index c0f4d16..60d5774 100644 (file)
@@ -278,6 +278,48 @@ func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected offsets of same size as destination vector rank}}
+  %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected strides of same size as source vector rank}}
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1]} : vector<4x4xf32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected source rank to be smaller than destination rank}}
+  %1 = vector.insert_strided_slice %b, %a {offsets = [2, 2], strides = [1, 1, 1]} : vector<4x8x16xf32> into vector<4x4xf32>
+}
+
+// -----
+
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected offsets dimension 0 to be confined to [0, 4)}}
+  %1 = vector.insert_strided_slice %a, %b {offsets = [100,100,100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [100, 100]} : vector<4x4xf32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  // expected-error@+1 {{op expected sum(offsets, source vector shape) dimension 1 to be confined to [1, 9)}}
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 7, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
+}
+
+// -----
+
 func @strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
   %1 = vector.strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
@@ -300,14 +342,14 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func @strided_slice(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{op expected offsets to be confined to [0, 4)}}
+  // expected-error@+1 {{op expected offsets dimension 0 to be confined to [0, 4)}}
   %1 = vector.strided_slice %arg0 {offsets = [100], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
 }
 
 // -----
 
 func @strided_slice(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{op expected sizes to be confined to [1, 5)}}
+  // expected-error@+1 {{op expected sizes dimension 0 to be confined to [1, 5)}}
   %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
 }
 
@@ -328,7 +370,7 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func @strided_slice(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{op expected sum(offsets, sizes) to be confined to [1, 5)}}
+  // expected-error@+1 {{op expected sum(offsets, sizes) dimension 0 to be confined to [1, 5)}}
   %1 = vector.strided_slice %arg0 {offsets = [2], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
 }
 
index 09b8815..a2b1ac3 100644 (file)
@@ -22,7 +22,7 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
   return
 }
 
-// CHECK-LABEL: extractelement
+// CHECK-LABEL: @extractelement
 func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
   //      CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>
   %1 = vector.extractelement %arg0[3 : i32] : vector<4x8x16xf32>
@@ -33,7 +33,7 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
   return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
 }
 
-// CHECK-LABEL: insertelement
+// CHECK-LABEL: @insertelement
 func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
   //      CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
   %1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
@@ -44,7 +44,7 @@ func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vec
   return
 }
 
-// CHECK-LABEL: outerproduct
+// CHECK-LABEL: @outerproduct
 func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   //     CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
   %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
@@ -53,7 +53,14 @@ func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8
   return %1 : vector<4x8xf32>
 }
 
-// CHECK-LABEL: strided_slice
+// CHECK-LABEL: @insert_strided_slice
+func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
+  //      CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @strided_slice
 func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
   //      CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
   %1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>