}];
}
+def Vector_ExtractSlicesOp :
+ Vector_Op<"extract_slices", [NoSideEffect]>,
+ Arguments<(ins AnyVector:$vector, I64ArrayAttr:$sizes,
+ I64ArrayAttr:$strides)>,
+ Results<(outs TupleOf<[AnyVector]>)> {
+ let summary = "vector extract slices operation";
+ let description = [{
+ Takes an N-d vector and returns a tuple of vector slices of 'vector',
+ based on 'sizes' and 'strides' parameters.
+
+ The arguments 'sizes' and 'strides' represent a specification for
+ generating the unrolling of 'vector' shape, which has all slices of shape
+ 'sizes' except for slices at dimension boundaries when 'vector' dimension
+ sizes are not a multiple of 'sizes'.
+
+ Each slice is returned at the tuple element index corresponding to the
+ linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
+ Currently, only unit strides are supported.
+
+ Examples:
+ ```
+ %0 = vector.transfer_read ...: vector<4x2xf32>
+
+ %1 = vector.extract_slices %0, [2, 2], [1, 1]
+ : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+
+ // Example with partial slices at dimension boundaries.
+ %2 = vector.transfer_read ...: vector<4x3xf32>
+
+ %3 = vector.extract_slices %2, [2, 2], [1, 1]
+ : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
+ vector<2x2xf32>, vector<2x2xf32>>
+ ```
+ }];
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, TupleType tupleType, " #
+ "Value *vector, ArrayRef<int64_t> sizes, " #
+ "ArrayRef<int64_t> strides">];
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return vector()->getType().cast<VectorType>();
+ }
+ TupleType getResultTupleType() {
+ return getResult()->getType().cast<TupleType>();
+ }
+ void getSizes(SmallVectorImpl<int64_t> &results);
+ void getStrides(SmallVectorImpl<int64_t> &results);
+ static StringRef getSizesAttrName() { return "sizes"; }
+ static StringRef getStridesAttrName() { return "strides"; }
+ }];
+}
+
def Vector_InsertOp :
Vector_Op<"insert", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
let hasCanonicalizer = 1;
}
+def Vector_TupleOp :
+ Vector_Op<"tuple", [NoSideEffect]>,
+ Arguments<(ins Variadic<AnyVector>:$vectors)>,
+ Results<(outs TupleOf<[AnyVector]>)> {
+ let summary = "make tuple of vectors operation";
+ let description = [{
+ Returns a tuple of its operands 'vectors'.
+
+ Note that this operation is used during the vector op unrolling
+ transformation and should be removed before lowering to lower-level
+ dialects.
+
+
+ Examples:
+ ```
+ %0 = vector.transfer_read ... : vector<2x2xf32>
+ %1 = vector.transfer_read ... : vector<2x1xf32>
+ %2 = vector.transfer_read ... : vector<2x2xf32>
+ %3 = vector.transfer_read ... : vector<2x1xf32>
+
+ %4 = vector.tuple %0, %1, %2, %3
+ : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>
+
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ TupleType getResultTupleType() {
+ return getResult()->getType().cast<TupleType>();
+ }
+ }];
+}
+
+def Vector_TupleGetOp :
+ Vector_Op<"tuple_get", [NoSideEffect]>,
+ Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>,
+ Results<(outs AnyVector)> {
+ let summary = "vector tuple get operation";
+ let description = [{
+ Returns the tuple element of 'vectors' at 'index'.
+
+ Note that this operation is used during the vector op unrolling
+ transformation and should be removed before lowering to lower-level
+ dialects.
+
+ Examples:
+ ```
+ %4 = vector.tuple %0, %1, %2, %3
+ : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
+
+ %5 = vector.tuple_get %4, 1
+ : tuple<vector<2x2xf32>, vector<2x1xf32>,
+ vector<2x2xf32>, vector<2x1xf32>>
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getResultVectorType() {
+ return getResult()->getType().cast<VectorType>();
+ }
+ unsigned getIndex() {
+ return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+ }
+ static StringRef getIndexAttrName() { return "index"; }
+ }];
+}
+
#endif // VECTOR_OPS
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
}
//===----------------------------------------------------------------------===//
+// ExtractSlicesOp
+//===----------------------------------------------------------------------===//
+
+void ExtractSlicesOp::build(Builder *builder, OperationState &result,
+ TupleType tupleType, Value *vector,
+ ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides) {
+ result.addOperands(vector);
+ auto sizesAttr = builder->getI64ArrayAttr(sizes);
+ auto stridesAttr = builder->getI64ArrayAttr(strides);
+ result.addTypes(tupleType);
+ result.addAttribute(getSizesAttrName(), sizesAttr);
+ result.addAttribute(getStridesAttrName(), stridesAttr);
+}
+
+static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType operandInfo;
+ ArrayAttr sizesAttr;
+ StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
+ ArrayAttr stridesAttr;
+ StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
+ VectorType vectorType;
+ TupleType resultTupleType;
+ return failure(
+ parser.parseOperand(operandInfo) || parser.parseComma() ||
+ parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
+ parser.parseComma() ||
+ parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(vectorType) ||
+ parser.parseKeywordType("into", resultTupleType) ||
+ parser.resolveOperand(operandInfo, vectorType, result.operands) ||
+ parser.addTypeToList(resultTupleType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
+ p << op.getOperationName() << ' ' << *op.vector() << ", ";
+ p << op.sizes() << ", " << op.strides();
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
+ ExtractSlicesOp::getStridesAttrName()});
+ p << " : " << op.vector()->getType();
+ p << " into " << op.getResultTupleType();
+}
+
+static LogicalResult
+isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
+ TupleType tupleType, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides) {
+ // Check for non-unit strides.
+ // TODO(b/144845578) Support non-1 strides.
+ if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
+ return op->emitError("requires unit strides");
+ // Check that 'vectorType' rank matches rank of tuple element vectors.
+ unsigned rank = vectorType.getRank();
+ auto is_vector_type_of_rank = [&](Type t) {
+ return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
+ };
+ if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
+ return op->emitError("requires vector tuple elements of rank ") << rank;
+ // Check that 'sizes' and 'strides' are of size == 'rank'.
+ if (sizes.size() != rank || strides.size() != rank)
+ return op->emitError("requires sizes and strides of rank ") << rank;
+
+ // Compute the number of slices in each dimension.
+ // TODO(andydavis) Move this into a slice generation helper function.
+ auto shape = vectorType.getShape();
+ SmallVector<int64_t, 4> dimSliceCounts(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ dimSliceCounts[i] = ceilDiv(shape[i], sizes[i]);
+ // Compute the strides between slices in each dimension.
+ SmallVector<int64_t, 4> sliceStrides(rank);
+ sliceStrides[rank - 1] = 1;
+ for (int i = rank - 2; i >= 0; --i)
+ sliceStrides[i] = sliceStrides[i + 1] * dimSliceCounts[i + 1];
+
+ // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
+ // and varify that the same matches the corresponding tuple element 'i'.
+ for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
+ // De-linearize w.r.t. 'sliceStrides'.
+ SmallVector<int64_t, 4> vectorOffsets(rank);
+ int64_t linearIndex = i;
+ for (unsigned j = 0; j < rank; ++j) {
+ vectorOffsets.push_back(linearIndex / sliceStrides[i]);
+ linearIndex %= sliceStrides[i];
+ }
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = mlir::functional::zipMap(
+ [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+ // Initialize 'sliceSizes' to target 'sizes'
+ SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
+ for (unsigned j = 0; j < rank; ++j) {
+ // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
+ sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
+ }
+ // Create slice VectorType type.
+ auto sliceVectorType =
+ VectorType::get(sliceSizes, vectorType.getElementType());
+ // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
+ if (sliceVectorType != tupleType.getType(i))
+ return op->emitError("invalid tuple element type ") << sliceVectorType;
+ }
+ return success();
+}
+
+static LogicalResult verify(ExtractSlicesOp op) {
+ SmallVector<int64_t, 4> sizes;
+ op.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ op.getStrides(strides);
+ return isValidExtractOrInsertSlicesType(
+ op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
+ sizes, strides);
+}
+
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+ SmallVectorImpl<int64_t> &results) {
+ for (auto attr : arrayAttr)
+ results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
+void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(sizes(), results);
+}
+
+void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(strides(), results);
+}
+
+//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
return success();
}
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
- SmallVectorImpl<int64_t> &results) {
- for (auto attr : arrayAttr)
- results.push_back(attr.cast<IntegerAttr>().getInt());
-}
-
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
MLIRContext *context) {
auto attrs = functional::map(
}
//===----------------------------------------------------------------------===//
+// TupleOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 4> operandInfos;
+ SmallVector<Type, 4> types;
+ auto loc = parser.getCurrentLocation();
+ auto *ctx = parser.getBuilder().getContext();
+ return failure(
+ parser.parseOperandList(operandInfos) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonTypeList(types) ||
+ parser.resolveOperands(operandInfos, types, loc, result.operands) ||
+ parser.addTypeToList(TupleType::get(types, ctx), result.types));
+}
+
+static void print(OpAsmPrinter &p, TupleOp op) {
+ p << op.getOperationName() << ' ';
+ p.printOperands(op.getOperands());
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : ";
+ interleaveComma(op.getOperation()->getOperandTypes(), p);
+}
+
+static LogicalResult verify(TupleOp op) { return success(); }
+
+//===----------------------------------------------------------------------===//
+// TupleGetOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTupleGetOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType operandInfo;
+ IntegerAttr indexAttr;
+ StringRef indexAttrName = TupleGetOp::getIndexAttrName();
+ Type indexType = parser.getBuilder().getIndexType();
+ TupleType tupleType;
+ VectorType resultVectorType;
+ if (parser.parseOperand(operandInfo) || parser.parseComma() ||
+ parser.parseAttribute(indexAttr, indexType, indexAttrName,
+ result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(tupleType) ||
+ parser.resolveOperand(operandInfo, tupleType, result.operands))
+ return failure();
+ if (indexAttr.getInt() < 0 ||
+ indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
+ return failure();
+ parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
+ return success();
+}
+
+static void print(OpAsmPrinter &p, TupleGetOp op) {
+ p << op.getOperationName() << ' ' << *op.getOperand() << ", " << op.index();
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
+ p << " : " << op.getOperand()->getType();
+}
+
+static LogicalResult verify(TupleGetOp op) {
+ auto tupleType = op.getOperand()->getType().cast<TupleType>();
+ if (op.getIndex() < 0 || op.getIndex() >= tupleType.size())
+ return op.emitOpError("tuple get index out of range");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//