Strided DMA support for DmaStartOp
authorUday Bondhugula <bondhugula@google.com>
Wed, 5 Dec 2018 23:30:25 +0000 (15:30 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:18:37 +0000 (14:18 -0700)
- add optional stride arguments for DmaStartOp
- add DmaStartOp::verify(), and missing test cases for DMA op's in
  test/IR/memory-ops.mlir.

PiperOrigin-RevId: 224232466

mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/StandardOps/StandardOps.h
mlir/lib/Parser/Parser.cpp
mlir/lib/StandardOps/StandardOps.cpp
mlir/test/IR/memory-ops.mlir

index beb1dc0c643a0fa6114e69a6f8e2e9badd9e6a72..3a2f633c88ca3d7d17be7ae77d92f4af3d328fe0 100644 (file)
@@ -293,6 +293,14 @@ public:
                                 int requiredOperandCount = -1,
                                 Delimiter delimiter = Delimiter::None) = 0;
 
+  /// Parse zero or more trailing SSA comma-separated trailing operand
+  /// references with a specified surrounding delimiter, and an optional
+  /// required operand count. A leading comma is expected before the operands.
+  virtual bool
+  parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                           int requiredOperandCount = -1,
+                           Delimiter delimiter = Delimiter::None) = 0;
+
   //===--------------------------------------------------------------------===//
   // Methods for interacting with the parser
   //===--------------------------------------------------------------------===//
index ffd903aa779e916c3d0229e0958ce118c6f359b2..562c932adbc4445d332ad94455cde7e41d892f0c 100644 (file)
@@ -316,10 +316,15 @@ private:
 // not be of the same dimensionality, but need to have the same elemental type.
 // The operands include the source and destination memref's each followed by its
 // indices, size of the data transfer in terms of the number of elements (of the
-// elemental type of the memref), and a tag memref with its indices. The tag
+// elemental type of the memref), a tag memref with its indices, and optionally
+// at the end, a stride and a number_of_elements_per_stride arguments. The tag
 // location is used by a DmaWaitOp to check for completion. The indices of the
 // source memref, destination memref, and the tag memref have the same
-// restrictions as any load/store in MLFunctions.
+// restrictions as any load/store in MLFunctions. The optional stride arguments
+// should be of 'index' type, and specify a stride for the slower memory space
+// (memory space with a lower memory space id), tranferring chunks of
+// number_of_elements_per_stride every stride until %num_elements are
+// transferred. Either both or no stride arguments should be specified.
 //
 // For example, a DmaStartOp operation that transfers 256 elements of a memref
 // '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
@@ -333,6 +338,15 @@ private:
 //     memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
 //     memref<1 x i32>, (d0) -> (d0), 2>
 //
+//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
+//   transfer %num_elt_per_stride elements every %stride elements apart from
+//   memory space 0 until %num_elements are transferred.
+//
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+//             %num_elt_per_stride :
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels.
 // TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
 class DmaStartOp
     : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
@@ -341,7 +355,8 @@ public:
                     SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
                     SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
                     SSAValue *numElements, SSAValue *tagMemRef,
-                    ArrayRef<SSAValue *> tagIndices);
+                    ArrayRef<SSAValue *> tagIndices, SSAValue *stride = nullptr,
+                    SSAValue *elementsPerStride = nullptr);
 
   // Returns the source MemRefType for this DMA operation.
   const SSAValue *getSrcMemRef() const { return getOperand(0); }
@@ -388,12 +403,19 @@ public:
   const SSAValue *getTagMemRef() const {
     return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
   }
+  // Returns the rank (number of indices) of the tag MemRefType.
+  unsigned getTagMemRefRank() const {
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
+  }
+
   // Returns the tag memref index for this DMA operation.
   llvm::iterator_range<Operation::const_operand_iterator>
   getTagIndices() const {
-    return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
-                getDstMemRefRank() + 1 + 1,
-            getOperation()->operand_end()};
+    unsigned tagIndexStartPos =
+        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
+    return {getOperation()->operand_begin() + tagIndexStartPos,
+            getOperation()->operand_begin() + tagIndexStartPos +
+                getTagMemRefRank()};
   }
 
   /// Returns true if this is a DMA from a faster memory space to a slower one.
@@ -418,9 +440,34 @@ public:
   static StringRef getOperationName() { return "dma_start"; }
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
+  bool verify() const;
+
   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context);
 
+  bool isStrided() const {
+    return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
+                                   1 + 1 + getTagMemRefRank();
+  }
+
+  SSAValue *getStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1 - 1);
+  }
+  const SSAValue *getStride() const {
+    return const_cast<DmaStartOp *>(this)->getStride();
+  }
+
+  SSAValue *getNumElementsPerStride() {
+    if (!isStrided())
+      return nullptr;
+    return getOperand(getNumOperands() - 1);
+  }
+  const SSAValue *getNumElementsPerStride() const {
+    return const_cast<DmaStartOp *>(this)->getNumElementsPerStride();
+  }
+
 protected:
   friend class ::mlir::Operation;
   explicit DmaStartOp(const Operation *state) : Op(state) {}
index a1a97b89663ed2539d2d68929b607e8300bbaa1a..26e80eefb163f8887ac05f976aa975b9f9fbe5d4 100644 (file)
@@ -2304,6 +2304,19 @@ public:
     return false;
   }
 
+  bool parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                                int requiredOperandCount,
+                                Delimiter delimiter) override {
+    if (parser.getToken().is(Token::comma)) {
+      parseComma();
+      return parseOperandList(result, requiredOperandCount, delimiter);
+    }
+    if (requiredOperandCount != -1)
+      return emitError(parser.getToken().getLoc(),
+                       "expected " + Twine(requiredOperandCount) + " operands");
+    return false;
+  }
+
   /// Parse a keyword followed by a type.
   bool parseKeywordType(const char *keyword, Type &result) override {
     if (parser.getTokenSpelling() != keyword)
index e46fd2319dab267c19714ccd98cd0a14420dcd05..4f837542046dfdaa449993f689aac4c8f44c48a1 100644 (file)
@@ -686,7 +686,8 @@ void DmaStartOp::build(Builder *builder, OperationState *result,
                        SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
                        SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
                        SSAValue *numElements, SSAValue *tagMemRef,
-                       ArrayRef<SSAValue *> tagIndices) {
+                       ArrayRef<SSAValue *> tagIndices, SSAValue *stride,
+                       SSAValue *elementsPerStride) {
   result->addOperands(srcMemRef);
   result->addOperands(srcIndices);
   result->addOperands(destMemRef);
@@ -694,6 +695,10 @@ void DmaStartOp::build(Builder *builder, OperationState *result,
   result->addOperands(numElements);
   result->addOperands(tagMemRef);
   result->addOperands(tagIndices);
+  if (stride) {
+    result->addOperands(stride);
+    result->addOperands(elementsPerStride);
+  }
 }
 
 void DmaStartOp::print(OpAsmPrinter *p) const {
@@ -705,6 +710,10 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
   *p << ", " << *getTagMemRef() << '[';
   p->printOperands(getTagIndices());
   *p << ']';
+  if (isStrided()) {
+    *p << ", " << *getStride();
+    *p << ", " << *getNumElementsPerStride();
+  }
   p->printOptionalAttrDict(getAttrs());
   *p << " : " << getSrcMemRef()->getType();
   *p << ", " << getDstMemRef()->getType();
@@ -728,6 +737,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType numElementsInfo;
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
+  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
 
   SmallVector<Type, 3> types;
   auto indexType = parser->getBuilder().getIndexType();
@@ -745,8 +755,20 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
       parser->parseComma() || parser->parseOperand(numElementsInfo) ||
       parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
       parser->parseOperandList(tagIndexInfos, -1,
-                               OpAsmParser::Delimiter::Square) ||
-      parser->parseColonTypeList(types))
+                               OpAsmParser::Delimiter::Square))
+    return true;
+
+  // Parse optional stride and elements per stride.
+  if (parser->parseTrailingOperandList(strideInfo)) {
+    return true;
+  }
+  if (!strideInfo.empty() && strideInfo.size() != 2) {
+    return parser->emitError(parser->getNameLoc(),
+                             "expected two stride related operands");
+  }
+  bool isStrided = strideInfo.size() == 2;
+
+  if (parser->parseColonTypeList(types))
     return true;
 
   if (types.size() != 3)
@@ -763,6 +785,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
       parser->resolveOperands(tagIndexInfos, indexType, result->operands))
     return true;
 
+  if (isStrided) {
+    if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
+        parser->resolveOperand(strideInfo[1], indexType, result->operands))
+      return true;
+  }
+
   // Check that source/destination index list size matches associated rank.
   if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
       dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
@@ -776,6 +804,21 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
   return false;
 }
 
+bool DmaStartOp::verify() const {
+  // DMAs from different memory spaces supported.
+  if (getSrcMemorySpace() == getDstMemorySpace()) {
+    return emitOpError("DMA should be between different memory spaces");
+  }
+
+  if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+                              getDstMemRefRank() + 3 + 1 &&
+      getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+                              getDstMemRefRank() + 3 + 1 + 2) {
+    return emitOpError("incorrect number of operands");
+  }
+  return false;
+}
+
 void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
   /// dma_start(memrefcast) -> dma_start
index 7d0c238e94e40dc7bfdb8bd693b9ad66f5234542..a70c7f5dd56fd584265d6ac35b5d042dbb7ae3f8 100644 (file)
@@ -61,3 +61,29 @@ bb0:
 
   return
 }
+
+// CHECK-LABEL: mlfunc @dma_ops()
+mlfunc @dma_ops() {
+  %c0 = constant 0 : index
+  %stride = constant 32 : index
+  %elt_per_stride = constant 16 : index
+
+  %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+  %Ah = alloc() : memref<256 x f32, (d0) -> (d0), 1>
+  %tag = alloc() : memref<1 x f32>
+
+  %num_elements = constant 256 : index
+
+  dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0] : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
+  dma_wait %tag[%c0], %num_elements : memref<1 x f32>
+  // CHECK: dma_start %0[%c0], %1[%c0], %c256, %2[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
+  // CHECK-NEXT:  dma_wait %2[%c0], %c256 : memref<1xf32>
+
+  // DMA with strides
+  dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0], %stride, %elt_per_stride : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
+  dma_wait %tag[%c0], %num_elements : memref<1 x f32>
+  // CHECK-NEXT  dma_start %0[%c0], %1[%c0], %c256, %2[%c0], %c32, %c16 : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
+  // CHECK-NEXT  dma_wait %2[%c0], %c256 : memref<1xf32>
+
+  return
+}