}];
}
+// TODO(andydavis) Add transformation which decomposes ReshapeOp into an
+// optimized sequence of vector rotate/shuffle/select operations.
+def Vector_ReshapeOp :
+ Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>,
+ Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
+ Variadic<Index>:$output_shape,
+ I64ArrayAttr:$fixed_vector_sizes,
+ I32ElementsAttr:$operand_segment_sizes)>,
+ Results<(outs AnyVector)> {
+ let summary = "vector reshape operation";
+ let description = [{
+ Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
+ fixed vector dimension 'fixed_vector_sizes' on the innermost vector
+ dimensions.
+
+ The parameters 'input_shape' and 'output_shape' represent valid data shapes
+ across fixed vector shapes. For example, if a vector has a valid data
+ shape [6] with fixed vector size [8], then the valid data elements are
+ assumed to be stored at the beginning of the vector with the remaining
+ vector elements undefined.
+
+ In the examples below, valid data elements are represented by an alphabetic
+ character, and undefined data elements are represented by '-'.
+
+ Example
+
+ vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
+
+ input: [a, b, c, d, e, f]
+
+ layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
+
+ vector layout: [a, b, c, d, e, f, -, -]
+
+ Example
+
+ vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
+
+ input: [a, b, c, d, e, f, g, h, i, j]
+
+ layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
+
+ vector layout: [[a, b, c, d, e, f, g, h],
+ [i, j, -, -, -, -, -, -]]
+
+ Example
+
+ vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
+ [2, 3]
+
+ input: [[a, b, c, d, e],
+ [f, g, h, i, j],
+ [k, l, m, n, o]]
+
+ layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5,
+ d0 mod 3, d1 mod 5)
+
+ vector layout: [[[[a, b, c],
+ [f, g, h]]
+ [[d, e, -],
+ [i, j, -]]],
+ [[[k, l, m],
+ [-, -, -]]
+ [[n, o, -],
+ [-, -, -]]]]
+
+ Example
+
+ %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+ input: [[a, b, c, d, e, f],
+ [g, h, i, j, k, l],
+ [m, n, o, p, q, r]]
+
+ layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+
+
+ Input vector: [[[a, b, c, d],
+ [e, f, -, -]],
+ [[g, h, i, j],
+ [k, l, -, -]],
+ [[m, n, o, p],
+ [q, r, -, -]]]
+
+ Output vector: [[[a, b, c, d],
+ [e, f, g, h],
+ [i, -, -, -]],
+ [[j, k, l, m],
+ [n, o, p, q],
+ [r, -, -, -]]]
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getInputVectorType() {
+ return vector()->getType().cast<VectorType>();
+ }
+ VectorType getOutputVectorType() {
+ return getResult()->getType().cast<VectorType>();
+ }
+
+ /// Returns as integer value the number of input shape operands.
+ int64_t getNumInputShapeSizes() { return input_shape().size(); }
+
+ /// Returns as integer value the number of output shape operands.
+ int64_t getNumOutputShapeSizes() { return output_shape().size(); }
+
+ void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
+
+ static StringRef getFixedVectorSizesAttrName() {
+ return "fixed_vector_sizes";
+ }
+ static StringRef getInputShapeAttrName() { return "input_shape"; }
+ static StringRef getOutputShapeAttrName() { return "output_shape"; }
+ }];
+}
+
def Vector_StridedSliceOp :
Vector_Op<"strided_slice", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
}
//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ReshapeOp op) {
+ p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape()
+ << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
+ SmallVector<StringRef, 2> elidedAttrs = {
+ ReshapeOp::getOperandSegmentSizeAttr(),
+ ReshapeOp::getFixedVectorSizesAttrName()};
+ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+ p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType();
+}
+
+// TODO(b/146516564) Consider passing number of inner vector dimensions that
+// are fixed, instead of their values in 'fixesVectorSizes' array attr.
+//
+// operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`,
+// `[` ssa-use-list `]`, `[` array-attribute `]`
+// `:` vector-type 'to' vector-type
+//
+static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType inputInfo;
+ SmallVector<OpAsmParser::OperandType, 4> inputShapeInfo;
+ SmallVector<OpAsmParser::OperandType, 4> outputShapeInfo;
+ ArrayAttr fixedVectorSizesAttr;
+ StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName();
+ auto indexType = parser.getBuilder().getIndexType();
+ if (parser.parseOperand(inputInfo) || parser.parseComma() ||
+ parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() ||
+ parser.parseOperandList(outputShapeInfo,
+ OpAsmParser::Delimiter::Square) ||
+ parser.parseComma()) {
+ return failure();
+ }
+
+ auto builder = parser.getBuilder();
+ result.addAttribute(
+ ReshapeOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, static_cast<int32_t>(inputShapeInfo.size()),
+ static_cast<int32_t>(outputShapeInfo.size())}));
+ Type inputType;
+ Type outputType;
+ return failure(
+ parser.parseAttribute(fixedVectorSizesAttr, attrName,
+ result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(inputType) ||
+ parser.resolveOperand(inputInfo, inputType, result.operands) ||
+ parser.resolveOperands(inputShapeInfo, indexType, result.operands) ||
+ parser.resolveOperands(outputShapeInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", outputType) ||
+ parser.addTypeToList(outputType, result.types));
+}
+
+static LogicalResult verify(ReshapeOp op) {
+ // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
+ auto inputVectorType = op.getInputVectorType();
+ auto outputVectorType = op.getOutputVectorType();
+ int64_t inputShapeRank = op.getNumInputShapeSizes();
+ int64_t outputShapeRank = op.getNumOutputShapeSizes();
+ SmallVector<int64_t, 4> fixedVectorSizes;
+ op.getFixedVectorSizes(fixedVectorSizes);
+ int64_t numFixedVectorSizes = fixedVectorSizes.size();
+
+ if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
+ return op.emitError("invalid input shape for vector type ")
+ << inputVectorType;
+
+ if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
+ return op.emitError("invalid output shape for vector type ")
+ << outputVectorType;
+
+ // Verify that the 'fixedVectorSizes' match a input/output vector shape
+ // suffix.
+ unsigned inputVectorRank = inputVectorType.getRank();
+ for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
+ unsigned index = inputVectorRank - numFixedVectorSizes - i;
+ if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
+ return op.emitError("fixed vector size must match input vector for dim ")
+ << i;
+ }
+
+ unsigned outputVectorRank = outputVectorType.getRank();
+ for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
+ unsigned index = outputVectorRank - numFixedVectorSizes - i;
+ if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
+ return op.emitError("fixed vector size must match output vector for dim ")
+ << i;
+ }
+
+ // If all shape operands are produced by constant ops, verify that product
+ // of dimensions for input/output shape match.
+ auto isDefByConstant = [](Value *operand) {
+ return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
+ };
+ if (llvm::all_of(op.input_shape(), isDefByConstant) &&
+ llvm::all_of(op.output_shape(), isDefByConstant)) {
+ int64_t numInputElements = 1;
+ for (auto *operand : op.input_shape())
+ numInputElements *=
+ cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+ int64_t numOutputElements = 1;
+ for (auto *operand : op.output_shape())
+ numOutputElements *=
+ cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+ if (numInputElements != numOutputElements)
+ return op.emitError("product of input and output shape sizes must match");
+ }
+ return success();
+}
+
+void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(fixed_vector_sizes(), results);
+}
+
+//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
%0 = vector.print %arg0 : f32
return %0
}
+
+// -----
+
+func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{invalid input shape for vector type}}
+ %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{invalid output shape for vector type}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{product of input and output shape sizes must match}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{fixed vector size must match input vector for dim 0}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x5xf32> to vector<2x3x4xf32>
+}
+
+// -----
+
+func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c9 = constant 9 : index
+ // expected-error@+1 {{fixed vector size must match output vector for dim 0}}
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x5xf32>
+}
vector.print %arg0 : vector<8x4xf32>
return
}
+
+// CHECK-LABEL: reshape
+func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
+ // CHECK: %[[C2:.*]] = constant 2 : index
+ %c2 = constant 2 : index
+ // CHECK: %[[C3:.*]] = constant 3 : index
+ %c3 = constant 3 : index
+ // CHECK: %[[C6:.*]] = constant 6 : index
+ %c6 = constant 6 : index
+ // CHECK: %[[C9:.*]] = constant 9 : index
+ %c9 = constant 9 : index
+ // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32>
+ %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+ return %1 : vector<2x3x4xf32>
+}