}
//===----------------------------------------------------------------------===//
+// BaseOpWithOffsetSizesAndStridesOp
+//===----------------------------------------------------------------------===//
+
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+ ArrayAttr arrayAttr,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ p << '[';
+ unsigned idx = 0;
+ llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+ int64_t val = a.cast<IntegerAttr>().getInt();
+ if (isDynamic(val))
+ p << values[idx++];
+ else
+ p << val;
+ });
+ p << ']';
+}
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encode the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+//
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+/// 2. `ssa` is filled with "[%arg0, %arg1]".
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+ StringRef attrName, int64_t dynVal,
+ SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+ if (failed(parser.parseLSquare()))
+ return failure();
+ // 0-D.
+ if (succeeded(parser.parseOptionalRSquare())) {
+ result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+ return success();
+ }
+
+ SmallVector<int64_t, 4> attrVals;
+ while (true) {
+ OpAsmParser::OperandType operand;
+ auto res = parser.parseOptionalOperand(operand);
+ if (res.hasValue() && succeeded(res.getValue())) {
+ ssa.push_back(operand);
+ attrVals.push_back(dynVal);
+ } else {
+ IntegerAttr attr;
+ if (failed(parser.parseAttribute<IntegerAttr>(attr)))
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ attrVals.push_back(attr.getInt());
+ }
+
+ if (succeeded(parser.parseOptionalComma()))
+ continue;
+ if (failed(parser.parseRSquare()))
+ return failure();
+ break;
+ }
+
+ auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+ result.addAttribute(attrName, arrayAttr);
+ return success();
+}
+
+/// Verify that a particular offset/size/stride static attribute is well-formed.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+ OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName,
+ ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
+ ValueRange values) {
+ /// Check static and dynamic offsets/sizes/strides breakdown.
+ if (attr.size() != expectedNumElements)
+ return op.emitError("expected ")
+ << expectedNumElements << " " << name << " values";
+ unsigned expectedNumDynamicEntries =
+ llvm::count_if(attr.getValue(), [&](Attribute attr) {
+ return isDynamic(attr.cast<IntegerAttr>().getInt());
+ });
+ if (values.size() != expectedNumDynamicEntries)
+ return op.emitError("expected ")
+ << expectedNumDynamicEntries << " dynamic " << name << " values";
+ return success();
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
+}
+
+/// Verify static attributes offsets/sizes/strides.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+ unsigned srcRank = op.getSourceRank();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "offset", srcRank, op.getStaticOffsetsAttrName(),
+ op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
+ op.offsets())))
+ return failure();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(),
+ ShapedType::isDynamic, op.sizes())))
+ return failure();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "stride", srcRank, op.getStaticStridesAttrName(),
+ op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+ op.strides())))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+/// Print of the form:
+/// ```
+/// `name` ssa-name to
+/// offset: `[` offset `]`
+/// sizes: `[` size-list `]`
+/// strides:`[` stride-list `]`
+/// `:` any-memref-type to strided-memref-type
+/// ```
+static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
+ << " to offset: ";
+ printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << ", sizes: ";
+ printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+ ShapedType::isDynamic);
+ p << ", strides: ";
+ printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset);
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+ MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+ MemRefReinterpretCastOp::getStaticSizesAttrName(),
+ MemRefReinterpretCastOp::getStaticStridesAttrName()});
+ p << ": " << op.source().getType() << " to " << op.getType();
+}
+
+/// Parse of the form:
+/// ```
+/// `name` ssa-name to
+/// offset: `[` offset `]`
+/// sizes: `[` size-list `]`
+/// strides:`[` stride-list `]`
+/// `:` any-memref-type to strided-memref-type
+/// ```
+static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
+ OperationState &result) {
+ // Parse `operand` and `offset`.
+ OpAsmParser::OperandType operand;
+ if (parser.parseOperand(operand))
+ return failure();
+
+ // Parse offset.
+ SmallVector<OpAsmParser::OperandType, 1> offset;
+ if (parser.parseKeyword("to") || parser.parseKeyword("offset") ||
+ parser.parseColon() ||
+ parseListOfOperandsOrIntegers(
+ parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+ ShapedType::kDynamicStrideOrOffset, offset) ||
+ parser.parseComma())
+ return failure();
+
+ // Parse `sizes`.
+ SmallVector<OpAsmParser::OperandType, 4> sizes;
+ if (parser.parseKeyword("sizes") || parser.parseColon() ||
+ parseListOfOperandsOrIntegers(
+ parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(),
+ ShapedType::kDynamicSize, sizes) ||
+ parser.parseComma())
+ return failure();
+
+ // Parse `strides`.
+ SmallVector<OpAsmParser::OperandType, 4> strides;
+ if (parser.parseKeyword("strides") || parser.parseColon() ||
+ parseListOfOperandsOrIntegers(
+ parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(),
+ ShapedType::kDynamicStrideOrOffset, strides))
+ return failure();
+
+ // Handle segment sizes.
+ auto b = parser.getBuilder();
+ SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offset.size()),
+ static_cast<int>(sizes.size()),
+ static_cast<int>(strides.size())};
+ result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+
+ b.getI32VectorAttr(segmentSizes));
+
+ // Parse types and resolve.
+ Type indexType = b.getIndexType();
+ Type operandType, resultType;
+ return failure(
+ (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(operandType) || parser.parseKeyword("to") ||
+ parser.parseType(resultType) ||
+ parser.resolveOperand(operand, operandType, result.operands) ||
+ parser.resolveOperands(offset, indexType, result.operands) ||
+ parser.resolveOperands(sizes, indexType, result.operands) ||
+ parser.resolveOperands(strides, indexType, result.operands) ||
+ parser.addTypeToList(resultType, result.types)));
+}
+
+static LogicalResult verify(MemRefReinterpretCastOp op) {
+ // The source and result memrefs should be in the same memory space.
+ auto srcType = op.source().getType().cast<BaseMemRefType>();
+ auto resultType = op.getType().cast<MemRefType>();
+ if (srcType.getMemorySpace() != resultType.getMemorySpace())
+ return op.emitError("different memory spaces specified for source type ")
+ << srcType << " and result memref type " << resultType;
+ if (srcType.getElementType() != resultType.getElementType())
+ return op.emitError("different element types specified for source type ")
+ << srcType << " and result memref type " << resultType;
+
+ // Verify that dynamic and static offset/sizes/strides arguments/attributes
+ // are consistent.
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset, op.offsets())))
+ return failure();
+ unsigned resultRank = op.getResultRank();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "size", resultRank, op.getStaticSizesAttrName(),
+ op.static_sizes(), ShapedType::isDynamic, op.sizes())))
+ return failure();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "stride", resultRank, op.getStaticStridesAttrName(),
+ op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+ op.strides())))
+ return failure();
+
+ // Extract source offset and strides.
+ int64_t resultOffset;
+ SmallVector<int64_t, 4> resultStrides;
+ if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ return failure();
+
+ // Match offset in result memref type and in static_offsets attribute.
+ int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+ if (resultOffset != expectedOffset)
+ return op.emitError("expected result type with offset = ")
+ << resultOffset << " instead of " << expectedOffset;
+
+ // Match sizes in result memref type and in static_sizes attribute.
+ for (auto &en :
+ llvm::enumerate(llvm::zip(resultType.getShape(),
+ extractFromI64ArrayAttr(op.static_sizes())))) {
+ int64_t resultSize = std::get<0>(en.value());
+ int64_t expectedSize = std::get<1>(en.value());
+ if (resultSize != expectedSize)
+ return op.emitError("expected result type with size = ")
+ << expectedSize << " instead of " << resultSize
+ << " in dim = " << en.index();
+ }
+
+ // Match strides in result memref type and in static_strides attribute.
+ for (auto &en : llvm::enumerate(llvm::zip(
+ resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+ int64_t resultStride = std::get<0>(en.value());
+ int64_t expectedStride = std::get<1>(en.value());
+ if (resultStride != expectedStride)
+ return op.emitError("expected result type with stride = ")
+ << expectedStride << " instead of " << resultStride
+ << " in dim = " << en.index();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// MemRefReshapeOp
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers(
- OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- p << "[";
- unsigned idx = 0;
- llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
- int64_t val = a.cast<IntegerAttr>().getInt();
- if (isDynamic(val))
- p << values[idx++];
- else
- p << val;
- });
- p << "] ";
-}
-
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-/// 2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
- StringRef attrName, int64_t dynVal,
- SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
- if (failed(parser.parseLSquare()))
- return failure();
- // 0-D.
- if (succeeded(parser.parseOptionalRSquare())) {
- result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
- return success();
- }
-
- SmallVector<int64_t, 4> attrVals;
- while (true) {
- OpAsmParser::OperandType operand;
- auto res = parser.parseOptionalOperand(operand);
- if (res.hasValue() && succeeded(res.getValue())) {
- ssa.push_back(operand);
- attrVals.push_back(dynVal);
- } else {
- Attribute attr;
- NamedAttrList placeholder;
- if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
- !attr.isa<IntegerAttr>())
- return parser.emitError(parser.getNameLoc())
- << "expected SSA value or integer";
- attrVals.push_back(attr.cast<IntegerAttr>().getInt());
- }
-
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- else
- break;
- }
-
- auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
- result.addAttribute(attrName, arrayAttr);
- return success();
-}
-
namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source();
printExtraOperands(p, op);
- printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
+ printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << ' ';
+ printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+ ShapedType::isDynamic);
+ p << ' ';
+ printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << ' ';
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{OpType::getSpecialAttrNames()});
p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
- OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
- llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
- /// Check static and dynamic offsets/sizes/strides breakdown.
- if (attr.size() != op.getSourceRank())
- return op.emitError("expected ")
- << op.getSourceRank() << " " << name << " values";
- unsigned expectedNumDynamicEntries =
- llvm::count_if(attr.getValue(), [&](Attribute attr) {
- return isDynamic(attr.cast<IntegerAttr>().getInt());
- });
- if (values.size() != expectedNumDynamicEntries)
- return op.emitError("expected ")
- << expectedNumDynamicEntries << " dynamic " << name << " values";
- return success();
-}
-
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
llvm::Optional<SmallVector<bool, 4>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape) {
llvm_unreachable("unexpected subview verification result");
}
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
- // Verify static attributes offsets/sizes/strides.
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset, op.offsets())))
- return failure();
-
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
- ShapedType::isDynamic, op.sizes())))
- return failure();
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset, op.strides())))
- return failure();
- return success();
-}
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {