From: Andy Davis Date: Thu, 19 Dec 2019 20:22:35 +0000 (-0800) Subject: [VectorOps] Add vector ReshapeOp to the VectorOps dialect. X-Git-Tag: llvmorg-11-init~1466^2~22 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1d798b1d27fb150de47266b009a414db46344f5a;p=platform%2Fupstream%2Fllvm.git [VectorOps] Add vector ReshapeOp to the VectorOps dialect. Adds vector ReshapeOp to the VectorOps dialect. An aggregate vector reshape operation, which aggregates multiple hardware vectors, can enable optimizations during decomposition (e.g. loading one input hardware vector and performing multiple rotate and scatter store operations to the vector output). PiperOrigin-RevId: 286440658 --- diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 98c0610..7dcac62 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -574,6 +574,123 @@ def Vector_OuterProductOp : }]; } +// TODO(andydavis) Add transformation which decomposes ReshapeOp into an +// optimized sequence of vector rotate/shuffle/select operations. +def Vector_ReshapeOp : + Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Variadic:$input_shape, + Variadic:$output_shape, + I64ArrayAttr:$fixed_vector_sizes, + I32ElementsAttr:$operand_segment_sizes)>, + Results<(outs AnyVector)> { + let summary = "vector reshape operation"; + let description = [{ + Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining + fixed vector dimension 'fixed_vector_sizes' on the innermost vector + dimensions. + + The parameters 'input_shape' and 'output_shape' represent valid data shapes + across fixed vector shapes. For example, if a vector has a valid data + shape [6] with fixed vector size [8], then the valid data elements are + assumed to be stored at the beginning of the vector with the remaining + vector elements undefined. + + In the examples below, valid data elements are represented by an alphabetic + character, and undefined data elements are represented by '-'. + + Example + + vector<1x8xf32> with valid data shape [6], fixed vector sizes [8] + + input: [a, b, c, d, e, f] + + layout map: (d0) -> (d0 floordiv 8, d0 mod 8) + + vector layout: [a, b, c, d, e, f, -, -] + + Example + + vector<2x8xf32> with valid data shape [10], fixed vector sizes [8] + + input: [a, b, c, d, e, f, g, h, i, j] + + layout map: (d0) -> (d0 floordiv 8, d0 mod 8) + + vector layout: [[a, b, c, d, e, f, g, h], + [i, j, -, -, -, -, -, -]] + + Example + + vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes + [2, 3] + + input: [[a, b, c, d, e], + [f, g, h, i, j], + [k, l, m, n, o]] + + layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5, + d0 mod 3, d1 mod 5) + + vector layout: [[[[a, b, c], + [f, g, h]] + [[d, e, -], + [i, j, -]]], + [[[k, l, m], + [-, -, -]] + [[n, o, -], + [-, -, -]]]] + + Example + + %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> + + input: [[a, b, c, d, e, f], + [g, h, i, j, k, l], + [m, n, o, p, q, r]] + + layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4) + + + Input vector: [[[a, b, c, d], + [e, f, -, -]], + [[g, h, i, j], + [k, l, -, -]], + [[m, n, o, p], + [q, r, -, -]]] + + Output vector: [[[a, b, c, d], + [e, f, g, h], + [i, -, -, -]], + [[j, k, l, m], + [n, o, p, q], + [r, -, -, -]]] + }]; + + let extraClassDeclaration = [{ + VectorType getInputVectorType() { + return vector()->getType().cast(); + } + VectorType getOutputVectorType() { + return getResult()->getType().cast(); + } + + /// Returns as integer value the number of input shape operands. + int64_t getNumInputShapeSizes() { return input_shape().size(); } + + /// Returns as integer value the number of output shape operands. + int64_t getNumOutputShapeSizes() { return output_shape().size(); } + + void getFixedVectorSizes(SmallVectorImpl &results); + + static StringRef getFixedVectorSizesAttrName() { + return "fixed_vector_sizes"; + } + static StringRef getInputShapeAttrName() { return "input_shape"; } + static StringRef getOutputShapeAttrName() { return "output_shape"; } + }]; +} + def Vector_StridedSliceOp : Vector_Op<"strided_slice", [NoSideEffect, PredOpTrait<"operand and result have same element type", diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index ff28334..541b542 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -1108,6 +1108,123 @@ static LogicalResult verify(OuterProductOp op) { } //===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, ReshapeOp op) { + p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape() + << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes(); + SmallVector elidedAttrs = { + ReshapeOp::getOperandSegmentSizeAttr(), + ReshapeOp::getFixedVectorSizesAttrName()}; + p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); + p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType(); +} + +// TODO(b/146516564) Consider passing number of inner vector dimensions that +// are fixed, instead of their values in 'fixesVectorSizes' array attr. +// +// operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`, +// `[` ssa-use-list `]`, `[` array-attribute `]` +// `:` vector-type 'to' vector-type +// +static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType inputInfo; + SmallVector inputShapeInfo; + SmallVector outputShapeInfo; + ArrayAttr fixedVectorSizesAttr; + StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName(); + auto indexType = parser.getBuilder().getIndexType(); + if (parser.parseOperand(inputInfo) || parser.parseComma() || + parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) || + parser.parseComma() || + parser.parseOperandList(outputShapeInfo, + OpAsmParser::Delimiter::Square) || + parser.parseComma()) { + return failure(); + } + + auto builder = parser.getBuilder(); + result.addAttribute( + ReshapeOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(inputShapeInfo.size()), + static_cast(outputShapeInfo.size())})); + Type inputType; + Type outputType; + return failure( + parser.parseAttribute(fixedVectorSizesAttr, attrName, + result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(inputType) || + parser.resolveOperand(inputInfo, inputType, result.operands) || + parser.resolveOperands(inputShapeInfo, indexType, result.operands) || + parser.resolveOperands(outputShapeInfo, indexType, result.operands) || + parser.parseKeywordType("to", outputType) || + parser.addTypeToList(outputType, result.types)); +} + +static LogicalResult verify(ReshapeOp op) { + // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. + auto inputVectorType = op.getInputVectorType(); + auto outputVectorType = op.getOutputVectorType(); + int64_t inputShapeRank = op.getNumInputShapeSizes(); + int64_t outputShapeRank = op.getNumOutputShapeSizes(); + SmallVector fixedVectorSizes; + op.getFixedVectorSizes(fixedVectorSizes); + int64_t numFixedVectorSizes = fixedVectorSizes.size(); + + if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) + return op.emitError("invalid input shape for vector type ") + << inputVectorType; + + if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) + return op.emitError("invalid output shape for vector type ") + << outputVectorType; + + // Verify that the 'fixedVectorSizes' match a input/output vector shape + // suffix. + unsigned inputVectorRank = inputVectorType.getRank(); + for (unsigned i = 0; i < numFixedVectorSizes; ++i) { + unsigned index = inputVectorRank - numFixedVectorSizes - i; + if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) + return op.emitError("fixed vector size must match input vector for dim ") + << i; + } + + unsigned outputVectorRank = outputVectorType.getRank(); + for (unsigned i = 0; i < numFixedVectorSizes; ++i) { + unsigned index = outputVectorRank - numFixedVectorSizes - i; + if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) + return op.emitError("fixed vector size must match output vector for dim ") + << i; + } + + // If all shape operands are produced by constant ops, verify that product + // of dimensions for input/output shape match. + auto isDefByConstant = [](Value *operand) { + return isa_and_nonnull(operand->getDefiningOp()); + }; + if (llvm::all_of(op.input_shape(), isDefByConstant) && + llvm::all_of(op.output_shape(), isDefByConstant)) { + int64_t numInputElements = 1; + for (auto *operand : op.input_shape()) + numInputElements *= + cast(operand->getDefiningOp()).getValue(); + int64_t numOutputElements = 1; + for (auto *operand : op.output_shape()) + numOutputElements *= + cast(operand->getDefiningOp()).getValue(); + if (numInputElements != numOutputElements) + return op.emitError("product of input and output shape sizes must match"); + } + return success(); +} + +void ReshapeOp::getFixedVectorSizes(SmallVectorImpl &results) { + populateFromInt64AttrArray(fixed_vector_sizes(), results); +} + +//===----------------------------------------------------------------------===// // StridedSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index d79c035..c208c92 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -826,3 +826,63 @@ func @print_no_result(%arg0 : f32) -> i32 { %0 = vector.print %arg0 : f32 return %0 } + +// ----- + +func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{invalid input shape for vector type}} + %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{invalid output shape for vector type}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{product of input and output shape sizes must match}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{fixed vector size must match input vector for dim 0}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x5xf32> to vector<2x3x4xf32> +} + +// ----- + +func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c9 = constant 9 : index + // expected-error@+1 {{fixed vector size must match output vector for dim 0}} + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x5xf32> +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 06d5728..e160799 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -205,3 +205,20 @@ func @vector_print(%arg0: vector<8x4xf32>) { vector.print %arg0 : vector<8x4xf32> return } + +// CHECK-LABEL: reshape +func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) { + // CHECK: %[[C2:.*]] = constant 2 : index + %c2 = constant 2 : index + // CHECK: %[[C3:.*]] = constant 3 : index + %c3 = constant 3 : index + // CHECK: %[[C6:.*]] = constant 6 : index + %c6 = constant 6 : index + // CHECK: %[[C9:.*]] = constant 9 : index + %c9 = constant 9 : index + // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32> + %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> + + return %1 : vector<2x3x4xf32> +}