[MLIR] Guard DMA-specific logic with DMA option
authorTim Shen <timshen@google.com>
Wed, 26 Feb 2020 04:16:08 +0000 (20:16 -0800)
committerTim Shen <timshen@google.com>
Wed, 11 Mar 2020 18:23:13 +0000 (11:23 -0700)
Differential Revision: https://reviews.llvm.org/D75963

mlir/lib/Transforms/Utils/LoopUtils.cpp

index c2cd233..1c9ac5e 100644 (file)
@@ -1411,22 +1411,24 @@ static LogicalResult generateCopy(
   auto numElementsSSA =
       top.create<ConstantIndexOp>(loc, numElements.getValue());
 
-  SmallVector<StrideInfo, 4> strideInfos;
-  getMultiLevelStrides(region, fastBufferShape, &strideInfos);
-
-  // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
-  // multi-level strides.
-  if (strideInfos.size() > 1) {
-    LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
-    return failure();
-  }
+  Value dmaStride = nullptr;
+  Value numEltPerDmaStride = nullptr;
+  if (copyOptions.generateDma) {
+    SmallVector<StrideInfo, 4> dmaStrideInfos;
+    getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
+
+    // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
+    // multi-level strides.
+    if (dmaStrideInfos.size() > 1) {
+      LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+      return failure();
+    }
 
-  Value stride = nullptr;
-  Value numEltPerStride = nullptr;
-  if (!strideInfos.empty()) {
-    stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
-    numEltPerStride =
-        top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
+    if (!dmaStrideInfos.empty()) {
+      dmaStride = top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
+      numEltPerDmaStride =
+          top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].numEltPerStride);
+    }
   }
 
   // Record the last operation where we want the memref replacement to end. We
@@ -1469,13 +1471,13 @@ static LogicalResult generateCopy(
       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
                                  fastMemRef, bufAffineMap, bufIndices,
                                  tagMemRef, tagAffineMap, tagIndices,
-                                 numElementsSSA, stride, numEltPerStride);
+                                 numElementsSSA, dmaStride, numEltPerDmaStride);
     } else {
       // DMA non-blocking write from fast buffer to the original memref.
       auto op = b.create<AffineDmaStartOp>(
           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
-          stride, numEltPerStride);
+          dmaStride, numEltPerDmaStride);
       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
       // end to mark end of block range being processed.
       if (isCopyOutAtEndOfBlock)