Add InsertSlicesOp to the VectorOps dialect.
authorAndy Davis <andydavis@google.com>
Mon, 16 Dec 2019 20:56:06 +0000 (12:56 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Dec 2019 20:56:38 +0000 (12:56 -0800)
PiperOrigin-RevId: 285830394

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index eb05821..50bf581 100644 (file)
@@ -352,7 +352,7 @@ def Vector_ExtractSlicesOp :
 
       %3 = vector.extract_slices %2, [2, 2], [1, 1]
         : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
-                                     vector<2x2xf32>, vector<2x2xf32>>
+                                     vector<2x2xf32>, vector<2x1xf32>>
     ```
   }];
   let builders = [OpBuilder<
@@ -439,6 +439,58 @@ def Vector_InsertOp :
   }];
 }
 
+def Vector_InsertSlicesOp :
+  Vector_Op<"insert_slices", [NoSideEffect]>,
+    Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes,
+                   I64ArrayAttr:$strides)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector insert slices operation";
+  let description = [{
+    Takes a tuple of vector slices and inserts them into the vector result
+    according to the 'sizes' and 'strides' parameters.
+
+    The arguments 'sizes' and 'strides' represent a specification for
+    generating the unrolling of 'vector' shape, which has all slices of shape
+    'sizes' except for slices at dimension boundaries when 'vector' dimension
+    sizes are not a multiple of 'sizes'.
+
+    Each slice in 'vectors' is at the tuple element index corresponding to the
+    linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
+    Currently, only unit strides are supported.
+
+    Examples:
+    ```
+      %0 = vector.extract_slices %0, [2, 2], [1, 1]
+        : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+
+      %1 = vector.insert_slices %0, [2, 2], [1, 1]
+        : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+
+      // Example with partial slices at dimension boundaries.
+      %3 = vector.extract_slices %2, [2, 2], [1, 1]
+        : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
+                                     vector<2x2xf32>, vector<2x1xf32>>
+
+      %4 = vector.insert_slices %3, [2, 2], [1, 1]
+        : tuple<vector<2x2xf32>, vector<2x1xf32>,
+                vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    TupleType getSourceTupleType() {
+      return vectors()->getType().cast<TupleType>();
+    }
+    VectorType getResultVectorType() {
+      return getResult()->getType().cast<VectorType>();
+    }
+    void getSizes(SmallVectorImpl<int64_t> &results);
+    void getStrides(SmallVectorImpl<int64_t> &results);
+    static StringRef getSizesAttrName() { return "sizes"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+  }];
+}
+
 def Vector_InsertStridedSliceOp :
   Vector_Op<"insert_strided_slice", [NoSideEffect,
     PredOpTrait<"operand #0 and result have same element type",
index fc8abd7..48fc0d4 100644 (file)
@@ -826,6 +826,60 @@ static LogicalResult verify(InsertOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// InsertSlicesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType operandInfo;
+  ArrayAttr sizesAttr;
+  StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
+  ArrayAttr stridesAttr;
+  StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
+  TupleType tupleType;
+  VectorType resultVectorType;
+  return failure(
+      parser.parseOperand(operandInfo) || parser.parseComma() ||
+      parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
+      parser.parseComma() ||
+      parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(tupleType) ||
+      parser.parseKeywordType("into", resultVectorType) ||
+      parser.resolveOperand(operandInfo, tupleType, result.operands) ||
+      parser.addTypeToList(resultVectorType, result.types));
+}
+
+static void print(OpAsmPrinter &p, InsertSlicesOp op) {
+  p << op.getOperationName() << ' ' << *op.vectors() << ", ";
+  p << op.sizes() << ", " << op.strides();
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
+                       InsertSlicesOp::getStridesAttrName()});
+  p << " : " << op.vectors()->getType();
+  p << " into " << op.getResultVectorType();
+}
+
+static LogicalResult verify(InsertSlicesOp op) {
+  SmallVector<int64_t, 4> sizes;
+  op.getSizes(sizes);
+  SmallVector<int64_t, 4> strides;
+  op.getStrides(strides);
+  return isValidExtractOrInsertSlicesType(
+      op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
+      sizes, strides);
+}
+
+void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(sizes(), results);
+}
+
+void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(strides(), results);
+}
+
+//===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
 
index c04c8ea..3c2dd60 100644 (file)
@@ -783,4 +783,38 @@ func @tuple_get_of_non_vectors(%arg0 : tuple<vector<4x2xf32>, index>) {
   return
 }
 
+// -----
+
+func @insert_slices_non_unit_strides(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>) {
+  // expected-error@+1 {{requires unit strides}}
+  %0 = vector.insert_slices %arg0, [2, 2], [1, 3]
+    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+  return
+}
+
+// -----
+
+func @insert_slices_tuple_element_wrong_rank(%arg0 : tuple<vector<2x2xf32>, vector<2x2x3xf32>>) {
+  // expected-error@+1 {{requires vector tuple elements of rank 2}}
+  %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+    : tuple<vector<2x2xf32>, vector<2x2x3xf32>> into vector<4x2xf32>
+  return
+}
 
+// -----
+
+func @insert_slices_sizes_strides_wrong_rank(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>) {
+  // expected-error@+1 {{requires sizes and strides of rank}}
+  %0 = vector.insert_slices %arg0, [2, 2], [1, 1, 1]
+    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+  return
+}
+
+// -----
+
+func @insert_slices_invalid_tuple_element_type(%arg0 : tuple<vector<2x2xf32>, vector<4x2xf32>>) {
+  // expected-error@+1 {{invalid tuple element type}}
+  %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+    : tuple<vector<2x2xf32>, vector<4x2xf32>> into vector<4x2xf32>
+  return
+}
index 69af80f..f1db45e 100644 (file)
@@ -189,3 +189,12 @@ func @extract_slices(%arg0 : vector<4x2xf32>)
   %3 = vector.tuple %1, %2 : vector<2x2xf32>, vector<2x2xf32>
   return %3 : tuple<vector<2x2xf32>, vector<2x2xf32>>
 }
+
+// CHECK-LABEL: insert_slices
+func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
+  -> (vector<4x2xf32>) {
+  // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+  %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}