int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
+ /// Parse zero or more SSA comma-separated trailing operand references with a
+ /// specified surrounding delimiter, and an optional required operand count. A
+ /// leading comma is expected before the operands.
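+ ///
+ /// As a minimal usage sketch (the operand names below are illustrative, not
+ /// part of this interface), an op's parse hook could accept an optional
+ /// trailing operand list as follows:
+ ///
+ ///   SmallVector<OpAsmParser::OperandType, 2> trailingOperands;
+ ///   if (parser->parseTrailingOperandList(trailingOperands))
+ ///     return true; // parse failure has already been reported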
+ virtual bool
+ parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount = -1,
+ Delimiter delimiter = Delimiter::None) = 0;
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
// not be of the same dimensionality, but need to have the same elemental type.
// The operands include the source and destination memref's each followed by its
// indices, size of the data transfer in terms of the number of elements (of the
-// elemental type of the memref), and a tag memref with its indices. The tag
+// elemental type of the memref), a tag memref with its indices, and optionally
+// at the end, a stride and a number_of_elements_per_stride argument. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
-// restrictions as any load/store in MLFunctions.
+// restrictions as any load/store in MLFunctions. The optional stride
+// arguments should be of 'index' type, and specify a stride for the slower
+// memory space (the memory space with the lower memory space id): chunks of
+// number_of_elements_per_stride elements are transferred, with consecutive
+// chunks starting stride elements apart, until %num_elements elements have
+// been transferred. Either both or neither of the stride arguments should be
+// specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
// memref<1 x i32>, (d0) -> (d0), 2>
//
+// If %stride and %num_elt_per_stride are specified, the DMA is expected to
+// transfer %num_elt_per_stride elements at a time, with the start of each
+// chunk %stride elements apart in memory space 0, until %num_elements
+// elements have been transferred.
+//
+// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+// %num_elt_per_stride :
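+//
+// For instance, with %num_elements = 256, %stride = 32 and
+// %num_elt_per_stride = 16, the semantics described above would transfer 16
+// chunks of 16 elements each, with the start of each chunk 32 elements apart
+// in memory space 0.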
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels.
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
SSAValue *numElements, SSAValue *tagMemRef,
- ArrayRef<SSAValue *> tagIndices);
+ ArrayRef<SSAValue *> tagIndices, SSAValue *stride = nullptr,
+ SSAValue *elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
const SSAValue *getSrcMemRef() const { return getOperand(0); }
const SSAValue *getTagMemRef() const {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
+ // Returns the rank (number of indices) of the tag MemRefType.
+ unsigned getTagMemRefRank() const {
+ return getTagMemRef()->getType().cast<MemRefType>().getRank();
+ }
+
// Returns the tag memref index for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator>
getTagIndices() const {
- return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
- getDstMemRefRank() + 1 + 1,
- getOperation()->operand_end()};
+ unsigned tagIndexStartPos =
+ 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
+ return {getOperation()->operand_begin() + tagIndexStartPos,
+ getOperation()->operand_begin() + tagIndexStartPos +
+ getTagMemRefRank()};
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
static StringRef getOperationName() { return "dma_start"; }
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
+ bool verify() const;
+
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
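+
+ // Returns true if this DMA operation carries the optional stride and
+ // number-of-elements-per-stride operands.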
+ bool isStrided() const {
+ return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
+ 1 + 1 + getTagMemRefRank();
+ }
+
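+ // Returns the stride operand, or nullptr if this DMA operation is not
+ // strided.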
+ SSAValue *getStride() {
+ if (!isStrided())
+ return nullptr;
+ return getOperand(getNumOperands() - 1 - 1);
+ }
+ const SSAValue *getStride() const {
+ return const_cast<DmaStartOp *>(this)->getStride();
+ }
+
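+ // Returns the number of elements transferred per stride, or nullptr if this
+ // DMA operation is not strided.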
+ SSAValue *getNumElementsPerStride() {
+ if (!isStrided())
+ return nullptr;
+ return getOperand(getNumOperands() - 1);
+ }
+ const SSAValue *getNumElementsPerStride() const {
+ return const_cast<DmaStartOp *>(this)->getNumElementsPerStride();
+ }
+
protected:
friend class ::mlir::Operation;
explicit DmaStartOp(const Operation *state) : Op(state) {}
return false;
}
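+
+ // Parse an optional list of trailing operands: if the next token is a comma,
+ // consume it and parse a normal operand list; otherwise, succeed with an
+ // empty list unless a specific operand count was required.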
+ bool parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+ int requiredOperandCount,
+ Delimiter delimiter) override {
+ if (parser.getToken().is(Token::comma)) {
+ parseComma();
+ return parseOperandList(result, requiredOperandCount, delimiter);
+ }
+ if (requiredOperandCount != -1)
+ return emitError(parser.getToken().getLoc(),
+ "expected " + Twine(requiredOperandCount) + " operands");
+ return false;
+ }
+
/// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type &result) override {
if (parser.getTokenSpelling() != keyword)
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
SSAValue *numElements, SSAValue *tagMemRef,
- ArrayRef<SSAValue *> tagIndices) {
+ ArrayRef<SSAValue *> tagIndices, SSAValue *stride,
+ SSAValue *elementsPerStride) {
result->addOperands(srcMemRef);
result->addOperands(srcIndices);
result->addOperands(destMemRef);
result->addOperands(destIndices);
result->addOperands(numElements);
result->addOperands(tagMemRef);
result->addOperands(tagIndices);
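+ // The stride operands are optional; when a stride is provided, the
+ // elements-per-stride count is expected as well.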
+ if (stride) {
+ result->addOperands(stride);
+ result->addOperands(elementsPerStride);
+ }
}
void DmaStartOp::print(OpAsmPrinter *p) const {
*p << ", " << *getTagMemRef() << '[';
p->printOperands(getTagIndices());
*p << ']';
+ if (isStrided()) {
+ *p << ", " << *getStride();
+ *p << ", " << *getNumElementsPerStride();
+ }
p->printOptionalAttrDict(getAttrs());
*p << " : " << getSrcMemRef()->getType();
*p << ", " << getDstMemRef()->getType();
OpAsmParser::OperandType numElementsInfo;
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
+ SmallVector<OpAsmParser::OperandType, 2> strideInfo;
SmallVector<Type, 3> types;
auto indexType = parser->getBuilder().getIndexType();
parser->parseComma() || parser->parseOperand(numElementsInfo) ||
parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
parser->parseOperandList(tagIndexInfos, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseColonTypeList(types))
+ OpAsmParser::Delimiter::Square))
+ return true;
+
+ // Parse optional stride and elements per stride.
+ if (parser->parseTrailingOperandList(strideInfo)) {
+ return true;
+ }
+ if (!strideInfo.empty() && strideInfo.size() != 2) {
+ return parser->emitError(parser->getNameLoc(),
+ "expected two stride related operands");
+ }
+ bool isStrided = strideInfo.size() == 2;
+
+ if (parser->parseColonTypeList(types))
return true;
if (types.size() != 3)
parser->resolveOperands(tagIndexInfos, indexType, result->operands))
return true;
+ if (isStrided) {
+ if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
+ parser->resolveOperand(strideInfo[1], indexType, result->operands))
+ return true;
+ }
+
// Check that source/destination index list size matches associated rank.
if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
return false;
}
+bool DmaStartOp::verify() const {
+ // DMAs are only supported between different memory spaces.
+ if (getSrcMemorySpace() == getDstMemorySpace()) {
+ return emitOpError("DMA should be between different memory spaces");
+ }
+
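+ // The operand count is the three memrefs plus %num_elements plus the
+ // indices of each memref (their ranks), with two additional operands when
+ // the DMA is strided.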
+ if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+ getDstMemRefRank() + 3 + 1 &&
+ getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+ getDstMemRefRank() + 3 + 1 + 2) {
+ return emitOpError("incorrect number of operands");
+ }
+ return false;
+}
+
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start
return
}
+
+// CHECK-LABEL: mlfunc @dma_ops()
+mlfunc @dma_ops() {
+ %c0 = constant 0 : index
+ %stride = constant 32 : index
+ %elt_per_stride = constant 16 : index
+
+ %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+ %Ah = alloc() : memref<256 x f32, (d0) -> (d0), 1>
+ %tag = alloc() : memref<1 x f32>
+
+ %num_elements = constant 256 : index
+
+ dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0] : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
+ dma_wait %tag[%c0], %num_elements : memref<1 x f32>
+ // CHECK: dma_start %0[%c0], %1[%c0], %c256, %2[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
+ // CHECK-NEXT: dma_wait %2[%c0], %c256 : memref<1xf32>
+
+ // DMA with strides
+ dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0], %stride, %elt_per_stride : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
+ dma_wait %tag[%c0], %num_elements : memref<1 x f32>
+ // CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c256, %2[%c0], %c32, %c16 : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
+ // CHECK-NEXT: dma_wait %2[%c0], %c256 : memref<1xf32>
+
+ return
+}