Adds ExtractSlicesOp to the VectorOps dialect.
authorAndy Davis <andydavis@google.com>
Mon, 16 Dec 2019 14:38:33 +0000 (06:38 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Dec 2019 14:39:09 +0000 (06:39 -0800)
ExtractSlicesOp extracts slices of its vector operand and with a specified tiling scheme.
This operation centralizes the tiling scheme around a single op, which simplifies vector op unrolling and subsequent pattern rewrite transformations.

PiperOrigin-RevId: 285761129

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index d87f101..883e1bc 100644 (file)
@@ -294,6 +294,58 @@ def Vector_ExtractOp :
   }];
 }
 
+def Vector_ExtractSlicesOp :
+  Vector_Op<"extract_slices", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$sizes,
+                   I64ArrayAttr:$strides)>,
+    Results<(outs TupleOf<[AnyVector]>)> {
+  let summary = "vector extract slices operation";
+  let description = [{
+    Takes an N-d vector and returns a tuple of vector slices of 'vector',
+    based on 'sizes' and 'strides' parameters.
+
+    The arguments 'sizes' and 'strides' represent a specification for
+    generating the unrolling of 'vector' shape, which has all slices of shape
+    'sizes' except for slices at dimension boundaries when 'vector' dimension
+    sizes are not a multiple of 'sizes'.
+
+    Each slice is returned at the tuple element index corresponding to the
+    linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
+    Currently, only unit strides are supported.
+
+    Examples:
+    ```
+      %0 = vector.transfer_read ...: vector<4x2xf32>
+
+      %1 = vector.extract_slices %0, [2, 2], [1, 1]
+        : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+
+      // Example with partial slices at dimension boundaries.
+      %2 = vector.transfer_read ...: vector<4x3xf32>
+
+      %3 = vector.extract_slices %2, [2, 2], [1, 1]
+        : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
+                                     vector<2x2xf32>, vector<2x2xf32>>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &result, TupleType tupleType, " #
+    "Value *vector, ArrayRef<int64_t> sizes, " #
+    "ArrayRef<int64_t> strides">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector()->getType().cast<VectorType>();
+    }
+    TupleType getResultTupleType() {
+      return getResult()->getType().cast<TupleType>();
+    }
+    void getSizes(SmallVectorImpl<int64_t> &results);
+    void getStrides(SmallVectorImpl<int64_t> &results);
+    static StringRef getSizesAttrName() { return "sizes"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+  }];
+}
+
 def Vector_InsertOp :
   Vector_Op<"insert", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
@@ -757,4 +809,71 @@ def Vector_CreateMaskOp :
   let hasCanonicalizer = 1;
 }
 
+def Vector_TupleOp :
+  Vector_Op<"tuple", [NoSideEffect]>,
+    Arguments<(ins Variadic<AnyVector>:$vectors)>,
+    Results<(outs TupleOf<[AnyVector]>)> {
+  let summary = "make tuple of vectors operation";
+  let description = [{
+    Returns a tuple of its operands 'vectors'.
+
+    Note that this operation is used during the vector op unrolling
+    transformation and should be removed before lowering to lower-level
+    dialects.
+    
+
+    Examples:
+    ```
+      %0 = vector.transfer_read ... : vector<2x2xf32>
+      %1 = vector.transfer_read ... : vector<2x1xf32>
+      %2 = vector.transfer_read ... : vector<2x2xf32>
+      %3 = vector.transfer_read ... : vector<2x1xf32>
+
+      %4 = vector.tuple %0, %1, %2, %3
+        : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>
+
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    TupleType getResultTupleType() {
+      return getResult()->getType().cast<TupleType>();
+    }
+  }];
+}
+
+def Vector_TupleGetOp :
+  Vector_Op<"tuple_get", [NoSideEffect]>,
+    Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector tuple get operation";
+  let description = [{
+    Returns the tuple element of 'vectors' at 'index'.
+
+    Note that this operation is used during the vector op unrolling
+    transformation and should be removed before lowering to lower-level
+    dialects.
+
+    Examples:
+    ```
+      %4 = vector.tuple %0, %1, %2, %3
+        : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
+
+      %5 = vector.tuple_get %4, 1
+        : tuple<vector<2x2xf32>, vector<2x1xf32>,
+                vector<2x2xf32>, vector<2x1xf32>>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getResultVectorType() {
+      return getResult()->getType().cast<VectorType>();
+    }
+    unsigned getIndex() {
+      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+    }
+    static StringRef getIndexAttrName() { return "index"; }
+  }];
+}
+
 #endif // VECTOR_OPS
index ae5579d..2dfa456 100644 (file)
@@ -31,6 +31,8 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/StringSet.h"
 
 using namespace mlir;
@@ -420,6 +422,138 @@ static LogicalResult verify(vector::ExtractOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractSlicesOp
+//===----------------------------------------------------------------------===//
+
+void ExtractSlicesOp::build(Builder *builder, OperationState &result,
+                            TupleType tupleType, Value *vector,
+                            ArrayRef<int64_t> sizes,
+                            ArrayRef<int64_t> strides) {
+  result.addOperands(vector);
+  auto sizesAttr = builder->getI64ArrayAttr(sizes);
+  auto stridesAttr = builder->getI64ArrayAttr(strides);
+  result.addTypes(tupleType);
+  result.addAttribute(getSizesAttrName(), sizesAttr);
+  result.addAttribute(getStridesAttrName(), stridesAttr);
+}
+
+static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
+                                        OperationState &result) {
+  OpAsmParser::OperandType operandInfo;
+  ArrayAttr sizesAttr;
+  StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
+  ArrayAttr stridesAttr;
+  StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
+  VectorType vectorType;
+  TupleType resultTupleType;
+  return failure(
+      parser.parseOperand(operandInfo) || parser.parseComma() ||
+      parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
+      parser.parseComma() ||
+      parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(vectorType) ||
+      parser.parseKeywordType("into", resultTupleType) ||
+      parser.resolveOperand(operandInfo, vectorType, result.operands) ||
+      parser.addTypeToList(resultTupleType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
+  p << op.getOperationName() << ' ' << *op.vector() << ", ";
+  p << op.sizes() << ", " << op.strides();
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
+                       ExtractSlicesOp::getStridesAttrName()});
+  p << " : " << op.vector()->getType();
+  p << " into " << op.getResultTupleType();
+}
+
+static LogicalResult
+isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
+                                 TupleType tupleType, ArrayRef<int64_t> sizes,
+                                 ArrayRef<int64_t> strides) {
+  // Check for non-unit strides.
+  // TODO(b/144845578) Support non-1 strides.
+  if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
+    return op->emitError("requires unit strides");
+  // Check that 'vectorType' rank matches rank of tuple element vectors.
+  unsigned rank = vectorType.getRank();
+  auto is_vector_type_of_rank = [&](Type t) {
+    return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
+  };
+  if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
+    return op->emitError("requires vector tuple elements of rank ") << rank;
+  // Check that 'sizes' and 'strides' are of size == 'rank'.
+  if (sizes.size() != rank || strides.size() != rank)
+    return op->emitError("requires sizes and strides of rank ") << rank;
+
+  // Compute the number of slices in each dimension.
+  // TODO(andydavis) Move this into a slice generation helper function.
+  auto shape = vectorType.getShape();
+  SmallVector<int64_t, 4> dimSliceCounts(rank);
+  for (unsigned i = 0; i < rank; ++i)
+    dimSliceCounts[i] = ceilDiv(shape[i], sizes[i]);
+  // Compute the strides between slices in each dimension.
+  SmallVector<int64_t, 4> sliceStrides(rank);
+  sliceStrides[rank - 1] = 1;
+  for (int i = rank - 2; i >= 0; --i)
+    sliceStrides[i] = sliceStrides[i + 1] * dimSliceCounts[i + 1];
+
+  // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
+  // and varify that the same matches the corresponding tuple element 'i'.
+  for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
+    // De-linearize w.r.t. 'sliceStrides'.
+    SmallVector<int64_t, 4> vectorOffsets(rank);
+    int64_t linearIndex = i;
+    for (unsigned j = 0; j < rank; ++j) {
+      vectorOffsets.push_back(linearIndex / sliceStrides[i]);
+      linearIndex %= sliceStrides[i];
+    }
+    // Convert from unrolled vector-space offsets to element-space offsets.
+    auto offsets = mlir::functional::zipMap(
+        [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+    // Initialize 'sliceSizes' to target 'sizes'
+    SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
+    for (unsigned j = 0; j < rank; ++j) {
+      // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
+      sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
+    }
+    // Create slice VectorType type.
+    auto sliceVectorType =
+        VectorType::get(sliceSizes, vectorType.getElementType());
+    // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
+    if (sliceVectorType != tupleType.getType(i))
+      return op->emitError("invalid tuple element type ") << sliceVectorType;
+  }
+  return success();
+}
+
+static LogicalResult verify(ExtractSlicesOp op) {
+  SmallVector<int64_t, 4> sizes;
+  op.getSizes(sizes);
+  SmallVector<int64_t, 4> strides;
+  op.getStrides(strides);
+  return isValidExtractOrInsertSlicesType(
+      op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
+      sizes, strides);
+}
+
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (auto attr : arrayAttr)
+    results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
+void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(sizes(), results);
+}
+
+void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
+  populateFromInt64AttrArray(strides(), results);
+}
+
+//===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
@@ -736,12 +870,6 @@ LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(attr.cast<IntegerAttr>().getInt());
-}
-
 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                   MLIRContext *context) {
   auto attrs = functional::map(
@@ -1188,6 +1316,73 @@ static LogicalResult verify(TypeCastOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// TupleOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> operandInfos;
+  SmallVector<Type, 4> types;
+  auto loc = parser.getCurrentLocation();
+  auto *ctx = parser.getBuilder().getContext();
+  return failure(
+      parser.parseOperandList(operandInfos) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonTypeList(types) ||
+      parser.resolveOperands(operandInfos, types, loc, result.operands) ||
+      parser.addTypeToList(TupleType::get(types, ctx), result.types));
+}
+
+static void print(OpAsmPrinter &p, TupleOp op) {
+  p << op.getOperationName() << ' ';
+  p.printOperands(op.getOperands());
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : ";
+  interleaveComma(op.getOperation()->getOperandTypes(), p);
+}
+
+static LogicalResult verify(TupleOp op) { return success(); }
+
+//===----------------------------------------------------------------------===//
+// TupleGetOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTupleGetOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  OpAsmParser::OperandType operandInfo;
+  IntegerAttr indexAttr;
+  StringRef indexAttrName = TupleGetOp::getIndexAttrName();
+  Type indexType = parser.getBuilder().getIndexType();
+  TupleType tupleType;
+  VectorType resultVectorType;
+  if (parser.parseOperand(operandInfo) || parser.parseComma() ||
+      parser.parseAttribute(indexAttr, indexType, indexAttrName,
+                            result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(tupleType) ||
+      parser.resolveOperand(operandInfo, tupleType, result.operands))
+    return failure();
+  if (indexAttr.getInt() < 0 ||
+      indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
+    return failure();
+  parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
+  return success();
+}
+
+static void print(OpAsmPrinter &p, TupleGetOp op) {
+  p << op.getOperationName() << ' ' << *op.getOperand() << ", " << op.index();
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
+  p << " : " << op.getOperand()->getType();
+}
+
+static LogicalResult verify(TupleGetOp op) {
+  auto tupleType = op.getOperand()->getType().cast<TupleType>();
+  if (op.getIndex() < 0 || op.getIndex() >= tupleType.size())
+    return op.emitOpError("tuple get index out of range");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
 
index 4f56e94..892c10c 100644 (file)
@@ -704,3 +704,59 @@ func @constant_mask_with_zero_mask_dim_size() {
   %0 = vector.constant_mask [0, 2] : vector<4x3xi1>
   return
 }
+
+
+// -----
+
+func @extract_slices_non_unit_strides(%arg0 : vector<4x2xf32>) {
+  // expected-error@+1 {{requires unit strides}}
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 3]
+    : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+  return
+}
+
+// -----
+
+func @extract_slices_tuple_element_wrong_rank(%arg0 : vector<4x2xf32>) {
+  // expected-error@+1 {{requires vector tuple elements of rank 2}}
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2x3xf32>>
+  return
+}
+
+// -----
+
+func @extract_slices_sizes_strides_wrong_rank(%arg0 : vector<4x2xf32>) {
+  // expected-error@+1 {{requires sizes and strides of rank}}
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1, 1]
+    : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+  return
+}
+
+// -----
+
+func @extract_slices_invalid_tuple_element_type(%arg0 : vector<4x2xf32>) {
+  // expected-error@+1 {{invalid tuple element type}}
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<4x2xf32>>
+  return
+}
+
+// -----
+
+func @tuple_of_non_vectors(%arg0 : vector<4x2xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{must be vector of any type values}}
+  %0 = vector.tuple %arg0, %c0 : vector<4x2xf32>, index
+  return
+}
+
+// -----
+
+func @tuple_get_of_non_vectors(%arg0 : tuple<vector<4x2xf32>, index>) {
+  // expected-error@+1 {{vector of any type values}}
+  %0 = vector.tuple_get %arg0, 0 : tuple<vector<4x2xf32>, index>
+  return
+}
+
+
index a5bafb4..1d28dea 100644 (file)
@@ -159,3 +159,15 @@ func @constant_vector_mask() {
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
   return
 }
+
+// CHECK-LABEL: extract_slices
+func @extract_slices(%arg0 : vector<4x2xf32>)
+  -> (tuple<vector<2x2xf32>, vector<2x2xf32>>) {
+  // CHECK: vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+  %1 = vector.tuple_get %0, 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+  %2 = vector.tuple_get %0, 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+  %3 = vector.tuple %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+  return %3 : tuple<vector<2x2xf32>, vector<2x2xf32>>
+}