From 834133c950fce120d0378d09718d32a320cbcd72 Mon Sep 17 00:00:00 2001 From: Anand Kodnani Date: Tue, 28 Jul 2020 10:37:16 -0700 Subject: [PATCH] [MLIR] Vector store to load forwarding The MemRefDataFlow pass does store to load forwarding only for affine store/loads. This patch updates the pass to use affine read/write interface which enables vector forwarding. Reviewed By: dcaballe, bondhugula, ftynse Differential Revision: https://reviews.llvm.org/D84302 --- .../Dialect/Affine/IR/AffineMemoryOpInterfaces.td | 21 +++++++++++++++++++++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 4 ++-- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 15 ++++++++------- mlir/test/Transforms/memref-dataflow-opt.mlir | 20 ++++++++++++++++++++ 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td index a093cb9..1f25073 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -81,6 +81,16 @@ def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> { op.getAffineMapAttr()}; }] >, + InterfaceMethod< + /*desc=*/"Returns the value read by this operation.", + /*retTy=*/"Value", + /*methodName=*/"getValue", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return cast(this->getOperation()); + }] + >, ]; } @@ -150,6 +160,17 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> { op.getAffineMapAttr()}; }] >, + InterfaceMethod< + /*desc=*/"Returns the value to store.", + /*retTy=*/"Value", + /*methodName=*/"getValueToStore", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + ConcreteOp op = cast(this->getOperation()); + return op.getOperand(op.getStoredValOperandIndex()); + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index ab7d96f..95e17aa 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -725,8 +725,8 @@ class AffineStoreOpBase traits = []> : Affine_Op])> { code extraClassDeclarationBase = [{ - /// Get value to be stored by store operation. - Value getValueToStore() { return getOperand(0); } + /// Returns the operand index of the value to be stored. + unsigned getStoredValOperandIndex() { return 0; } /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 1; } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 7220fd1..7924b46 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -63,7 +63,7 @@ namespace { struct MemRefDataFlowOpt : public MemRefDataFlowOptBase { void runOnFunction() override; - void forwardStoreToLoad(AffineLoadOp loadOp); + void forwardStoreToLoad(AffineReadOpInterface loadOp); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; @@ -84,14 +84,14 @@ std::unique_ptr> mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // if needed. -void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { +void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) { // First pass over the use list to get the minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(loadOp); for (auto *user : loadOp.getMemRef().getUsers()) { - auto storeOp = dyn_cast(user); + auto storeOp = dyn_cast(user); if (!storeOp) continue; unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp); @@ -167,8 +167,9 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { return; // Perform the actual store to load forwarding. - Value storeVal = cast(lastWriteStoreOp).getValueToStore(); - loadOp.replaceAllUsesWith(storeVal); + Value storeVal = + cast(lastWriteStoreOp).getValueToStore(); + loadOp.getValue().replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); // Record this to erase later. @@ -190,7 +191,7 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform store to load forwarding. - f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); }); + f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) @@ -207,7 +208,7 @@ void MemRefDataFlowOpt::runOnFunction() { // could still erase it if the call had no side-effects. continue; if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) { - return !isa(ownerOp); + return !isa(ownerOp); })) continue; diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index 6d5288c..dfda193 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -280,3 +280,23 @@ func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index) { } return } + +// The test checks for value forwarding from vector stores to vector loads. +// The value loaded from %in can directly be stored to %out by eliminating +// store and load from %tmp. +func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) { + %tmp = alloc() : memref<512xf32> + affine.for %i = 0 to 16 { + %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32> + affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32> + %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<32xf32> + affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32> + } + return +} + +// CHECK-LABEL: func @vector_forwarding +// CHECK: affine.for %{{.*}} = 0 to 16 { +// CHECK-NEXT: %[[LDVAL:.*]] = affine.vector_load +// CHECK-NEXT: affine.vector_store %[[LDVAL]],{{.*}} +// CHECK-NEXT: } -- 2.7.4