From 9f77faae87dbb6231e91b7ed26589d828d48e7e9 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 5 Dec 2018 15:30:25 -0800 Subject: [PATCH] Strided DMA support for DmaStartOp - add optional stride arguments for DmaStartOp - add DmaStartOp::verify(), and missing test cases for DMA op's in test/IR/memory-ops.mlir. PiperOrigin-RevId: 224232466 --- mlir/include/mlir/IR/OpImplementation.h | 8 +++ mlir/include/mlir/StandardOps/StandardOps.h | 59 ++++++++++++++++++--- mlir/lib/Parser/Parser.cpp | 13 +++++ mlir/lib/StandardOps/StandardOps.cpp | 49 +++++++++++++++-- mlir/test/IR/memory-ops.mlir | 26 +++++++++ 5 files changed, 146 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index beb1dc0c643a..3a2f633c88ca 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -293,6 +293,14 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; + /// Parse zero or more trailing comma-separated SSA operand + /// references with a specified surrounding delimiter, and an optional + /// required operand count. A leading comma is expected before the operands. + virtual bool + parseTrailingOperandList(SmallVectorImpl &result, + int requiredOperandCount = -1, + Delimiter delimiter = Delimiter::None) = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index ffd903aa779e..562c932adbc4 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -316,10 +316,15 @@ private: // not be of the same dimensionality, but need to have the same elemental type. 
// The operands include the source and destination memref's each followed by its // indices, size of the data transfer in terms of the number of elements (of the -// elemental type of the memref), and a tag memref with its indices. The tag +// elemental type of the memref), a tag memref with its indices, and optionally +// at the end, stride and number_of_elements_per_stride arguments. The tag // location is used by a DmaWaitOp to check for completion. The indices of the // source memref, destination memref, and the tag memref have the same -// restrictions as any load/store in MLFunctions. +// restrictions as any load/store in MLFunctions. The optional stride arguments +// should be of 'index' type, and specify a stride for the slower memory space +// (memory space with a lower memory space id), transferring chunks of +// number_of_elements_per_stride every stride until %num_elements are +// transferred. Either both or no stride arguments should be specified. // // For example, a DmaStartOp operation that transfers 256 elements of a memref // '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space @@ -333,6 +338,15 @@ private: // memref<2 x 1024 x f32>, (d0) -> (d0), 1>, // memref<1 x i32>, (d0) -> (d0), 2> // +// If %stride and %num_elt_per_stride are specified, the DMA is expected to +// transfer %num_elt_per_stride elements at a time from memory space 0, every +// %stride elements apart, until %num_elements are transferred. +// +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, +// %num_elt_per_stride : +// +// TODO(mlir-team): add additional operands to allow source and destination +// striding, and multiple stride levels. // TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. 
class DmaStartOp : public Op { @@ -341,7 +355,8 @@ public: SSAValue *srcMemRef, ArrayRef srcIndices, SSAValue *destMemRef, ArrayRef destIndices, SSAValue *numElements, SSAValue *tagMemRef, - ArrayRef tagIndices); + ArrayRef tagIndices, SSAValue *stride = nullptr, + SSAValue *elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. const SSAValue *getSrcMemRef() const { return getOperand(0); } @@ -388,12 +403,19 @@ public: const SSAValue *getTagMemRef() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } + // Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() const { + return getTagMemRef()->getType().cast().getRank(); + } + // Returns the tag memref index for this DMA operation. llvm::iterator_range getTagIndices() const { - return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + - getDstMemRefRank() + 1 + 1, - getOperation()->operand_end()}; + unsigned tagIndexStartPos = + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; + return {getOperation()->operand_begin() + tagIndexStartPos, + getOperation()->operand_begin() + tagIndexStartPos + + getTagMemRefRank()}; } /// Returns true if this is a DMA from a faster memory space to a slower one. 
@@ -418,9 +440,34 @@ public: static StringRef getOperationName() { return "dma_start"; } static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; + bool verify() const; + static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); + bool isStrided() const { + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + + 1 + 1 + getTagMemRefRank(); + } + + SSAValue *getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + const SSAValue *getStride() const { + return const_cast(this)->getStride(); + } + + SSAValue *getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } + const SSAValue *getNumElementsPerStride() const { + return const_cast(this)->getNumElementsPerStride(); + } + protected: friend class ::mlir::Operation; explicit DmaStartOp(const Operation *state) : Op(state) {} diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index a1a97b89663e..26e80eefb163 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2304,6 +2304,19 @@ public: return false; } + bool parseTrailingOperandList(SmallVectorImpl &result, + int requiredOperandCount, + Delimiter delimiter) override { + if (parser.getToken().is(Token::comma)) { + parseComma(); + return parseOperandList(result, requiredOperandCount, delimiter); + } + if (requiredOperandCount != -1) + return emitError(parser.getToken().getLoc(), + "expected " + Twine(requiredOperandCount) + " operands"); + return false; + } + /// Parse a keyword followed by a type. 
bool parseKeywordType(const char *keyword, Type &result) override { if (parser.getTokenSpelling() != keyword) diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index e46fd2319dab..4f837542046d 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -686,7 +686,8 @@ void DmaStartOp::build(Builder *builder, OperationState *result, SSAValue *srcMemRef, ArrayRef srcIndices, SSAValue *destMemRef, ArrayRef destIndices, SSAValue *numElements, SSAValue *tagMemRef, - ArrayRef tagIndices) { + ArrayRef tagIndices, SSAValue *stride, + SSAValue *elementsPerStride) { result->addOperands(srcMemRef); result->addOperands(srcIndices); result->addOperands(destMemRef); @@ -694,6 +695,10 @@ void DmaStartOp::build(Builder *builder, OperationState *result, result->addOperands(numElements); result->addOperands(tagMemRef); result->addOperands(tagIndices); + if (stride) { + result->addOperands(stride); + result->addOperands(elementsPerStride); + } } void DmaStartOp::print(OpAsmPrinter *p) const { @@ -705,6 +710,10 @@ void DmaStartOp::print(OpAsmPrinter *p) const { *p << ", " << *getTagMemRef() << '['; p->printOperands(getTagIndices()); *p << ']'; + if (isStrided()) { + *p << ", " << *getStride(); + *p << ", " << *getNumElementsPerStride(); + } p->printOptionalAttrDict(getAttrs()); *p << " : " << getSrcMemRef()->getType(); *p << ", " << getDstMemRef()->getType(); @@ -728,6 +737,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType numElementsInfo; OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; + SmallVector strideInfo; SmallVector types; auto indexType = parser->getBuilder().getIndexType(); @@ -745,8 +755,20 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseComma() || parser->parseOperand(numElementsInfo) || parser->parseComma() || parser->parseOperand(tagMemrefInfo) || parser->parseOperandList(tagIndexInfos, -1, 
- OpAsmParser::Delimiter::Square) || - parser->parseColonTypeList(types)) + OpAsmParser::Delimiter::Square)) + return true; + + // Parse optional stride and elements per stride. + if (parser->parseTrailingOperandList(strideInfo)) { + return true; + } + if (!strideInfo.empty() && strideInfo.size() != 2) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + bool isStrided = strideInfo.size() == 2; + + if (parser->parseColonTypeList(types)) return true; if (types.size() != 3) @@ -763,6 +785,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperands(tagIndexInfos, indexType, result->operands)) return true; + if (isStrided) { + if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || + parser->resolveOperand(strideInfo[1], indexType, result->operands)) + return true; + } + // Check that source/destination index list size matches associated rank. if (srcIndexInfos.size() != types[0].cast().getRank() || dstIndexInfos.size() != types[1].cast().getRank()) @@ -776,6 +804,21 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { return false; } +bool DmaStartOp::verify() const { + // Only DMAs between different memory spaces are supported. 
+ if (getSrcMemorySpace() == getDstMemorySpace()) { + return emitOpError("DMA should be between different memory spaces"); + } + + if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 && + getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + return false; +} + void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start diff --git a/mlir/test/IR/memory-ops.mlir b/mlir/test/IR/memory-ops.mlir index 7d0c238e94e4..a70c7f5dd56f 100644 --- a/mlir/test/IR/memory-ops.mlir +++ b/mlir/test/IR/memory-ops.mlir @@ -61,3 +61,29 @@ bb0: return } + +// CHECK-LABEL: mlfunc @dma_ops() +mlfunc @dma_ops() { + %c0 = constant 0 : index + %stride = constant 32 : index + %elt_per_stride = constant 16 : index + + %A = alloc() : memref<256 x f32, (d0) -> (d0), 0> + %Ah = alloc() : memref<256 x f32, (d0) -> (d0), 1> + %tag = alloc() : memref<1 x f32> + + %num_elements = constant 256 : index + + dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0] : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32> + dma_wait %tag[%c0], %num_elements : memref<1 x f32> + // CHECK: dma_start %0[%c0], %1[%c0], %c256, %2[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xf32> + // CHECK-NEXT: dma_wait %2[%c0], %c256 : memref<1xf32> + + // DMA with strides + dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0], %stride, %elt_per_stride : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32> + dma_wait %tag[%c0], %num_elements : memref<1 x f32> + // CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c256, %2[%c0], %c32, %c16 : memref<256xf32>, memref<256xf32, 1>, memref<1xf32> + // CHECK-NEXT: dma_wait %2[%c0], %c256 : memref<1xf32> + + return +} -- 2.34.1