[MLIR] Vector store to load forwarding
authorAnand Kodnani <anand.kodnani@intel.com>
Tue, 28 Jul 2020 17:37:16 +0000 (10:37 -0700)
committerDiego Caballero <diego.caballero@intel.com>
Tue, 28 Jul 2020 18:30:54 +0000 (11:30 -0700)
The MemRefDataFlow pass does store to load forwarding
only for affine store/loads. This patch updates the pass
to use affine read/write interface which enables vector
forwarding.

Reviewed By: dcaballe, bondhugula, ftynse

Differential Revision: https://reviews.llvm.org/D84302

mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/lib/Transforms/MemRefDataFlowOpt.cpp
mlir/test/Transforms/memref-dataflow-opt.mlir

index a093cb9..1f25073 100644 (file)
@@ -81,6 +81,16 @@ def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value read by this operation.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValue",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return cast<ConcreteOp>(this->getOperation());
+      }]
+    >,
   ];
 }
 
@@ -150,6 +160,17 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
                 op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValueToStore",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
 
index ab7d96f..95e17aa 100644 (file)
@@ -725,8 +725,8 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
         [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
   code extraClassDeclarationBase = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
+    /// Returns the operand index of the value to be stored.
+    unsigned getStoredValOperandIndex() { return 0; }
 
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
index 7220fd1..7924b46 100644 (file)
@@ -63,7 +63,7 @@ namespace {
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;
 
-  void forwardStoreToLoad(AffineLoadOp loadOp);
+  void forwardStoreToLoad(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
@@ -84,14 +84,14 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
 
 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
   SmallVector<Operation *, 8> storeOps;
   unsigned minSurroundingLoops = getNestingDepth(loadOp);
   for (auto *user : loadOp.getMemRef().getUsers()) {
-    auto storeOp = dyn_cast<AffineStoreOp>(user);
+    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
     if (!storeOp)
       continue;
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
@@ -167,8 +167,9 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
     return;
 
   // Perform the actual store to load forwarding.
-  Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
-  loadOp.replaceAllUsesWith(storeVal);
+  Value storeVal =
+    cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
@@ -190,7 +191,7 @@ void MemRefDataFlowOpt::runOnFunction() {
   memrefsToErase.clear();
 
   // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
 
   // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *loadOp : loadOpsToErase)
@@ -207,7 +208,7 @@ void MemRefDataFlowOpt::runOnFunction() {
       // could still erase it if the call had no side-effects.
       continue;
     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
-          return !isa<AffineStoreOp, DeallocOp>(ownerOp);
+          return !isa<AffineWriteOpInterface, DeallocOp>(ownerOp);
         }))
       continue;
 
index 6d5288c..dfda193 100644 (file)
@@ -280,3 +280,23 @@ func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index) {
   }
   return
 }
+
+// The test checks for value forwarding from vector stores to vector loads.
+// The value loaded from %in can directly be stored to %out by eliminating
+// store and load from %tmp.
+func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %tmp = alloc() : memref<512xf32>
+  affine.for %i = 0 to 16 {
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_forwarding
+// CHECK:      affine.for %{{.*}} = 0 to 16 {
+// CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
+// CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
+// CHECK-NEXT: }