Separating the AffineMapAccessInterface from AffineRead/WriteOp interface so that dialects which extend Affine capabilities (e.g. PlaidML PXA = parallel extensions for Affine) can utilize relevant passes (e.g. MemRef normalization).
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D96284
}]
>,
InterfaceMethod<
- /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
- /*retTy=*/"NamedAttribute",
- /*methodName=*/"getAffineMapAttrForMemRef",
- /*args=*/(ins "Value":$memref),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- assert(memref == getMemRef());
- return {Identifier::get(op.getMapAttrName(), op.getContext()),
- op.getAffineMapAttr()};
- }]
- >,
- InterfaceMethod<
/*desc=*/"Returns the value read by this operation.",
/*retTy=*/"Value",
/*methodName=*/"getValue",
}]
>,
InterfaceMethod<
- /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
- /*retTy=*/"NamedAttribute",
- /*methodName=*/"getAffineMapAttrForMemRef",
- /*args=*/(ins "Value":$memref),
+ /*desc=*/"Returns the value to store.",
+ /*retTy=*/"Value",
+ /*methodName=*/"getValueToStore",
+ /*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- assert(memref == getMemRef());
- return {Identifier::get(op.getMapAttrName(), op.getContext()),
- op.getAffineMapAttr()};
+ return op.getOperand(op.getStoredValOperandIndex());
}]
>,
+ ];
+}
+
+def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> {
+ let description = [{
+ Interface to query the AffineMap used to dereference and access a given
+ memref. Implementers of this interface must operate on at least one
+ memref operand. The memref argument given to this interface much match
+ one of those memref operands.
+ }];
+
+ let methods = [
InterfaceMethod<
- /*desc=*/"Returns the value to store.",
- /*retTy=*/"Value",
- /*methodName=*/"getValueToStore",
- /*args=*/(ins),
+ /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
+ /*retTy=*/"NamedAttribute",
+ /*methodName=*/"getAffineMapAttrForMemRef",
+ /*args=*/(ins "Value":$memref),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- return op.getOperand(op.getStoredValOperandIndex());
+ assert(memref == op.getMemRef() &&
+ "Expected memref argument to match memref operand");
+ return {Identifier::get(op.getMapAttrName(), op.getContext()),
+ op.getAffineMapAttr()};
}]
>,
];
// TODO: Consider replacing src/dst memref indices with view memrefs.
class AffineDmaStartOp
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
- OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+ OpTrait::VariadicOperands, OpTrait::ZeroResult,
+ AffineMapAccessInterface::Trait> {
public:
using Op::Op;
getTagMap().getNumInputs());
}
+ /// Impelements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
if (memref == getSrcMemRef())
//
class AffineDmaWaitOp
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
- OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+ OpTrait::VariadicOperands, OpTrait::ZeroResult,
+ AffineMapAccessInterface::Trait> {
public:
using Op::Op;
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
+ /// Impelements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getTagMemRef());
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineReadOpInterface>,
+ DeclareOpInterfaceMethods<AffineMapAccessInterface>,
MemRefsNormalizable])> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
let hasFolder = 1;
}
-def AffinePrefetchOp : Affine_Op<"prefetch"> {
+def AffinePrefetchOp : Affine_Op<"prefetch",
+ [DeclareOpInterfaceMethods<AffineMapAccessInterface>]> {
let summary = "affine prefetch operation";
let description = [{
The "affine.prefetch" op prefetches data from a memref location described
return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
+ /// Impelements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value mref) {
- assert(mref == memref());
+ assert(mref == memref() &&
+ "Expected mref argument to match memref operand");
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineWriteOpInterface>,
+ DeclareOpInterfaceMethods<AffineMapAccessInterface>,
MemRefsNormalizable])> {
code extraClassDeclarationBase = [{
/// Returns the operand index of the value to be stored.
return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
- /// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value memref) {
- assert(memref == getMemRef());
- return {Identifier::get(getMapAttrName(), getContext()),
- getAffineMapAttr()};
- }
-
static StringRef getMapAttrName() { return "map"; }
}];
}
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist);
-static bool isMemRefDereferencingOp(Operation &op) {
- // TODO: Support DMA Ops.
- return isa<AffineReadOpInterface, AffineWriteOpInterface>(op);
-}
-
// Returns true if the individual op is loop invariant.
bool isOpLoopInvariant(Operation &op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
// which are themselves not being hoisted.
definedOps.insert(&op);
- if (isMemRefDereferencingOp(op)) {
+ if (isa<AffineMapAccessInterface>(op)) {
Value memref = isa<AffineReadOpInterface>(op)
? cast<AffineReadOpInterface>(op).getMemRef()
: cast<AffineWriteOpInterface>(op).getMemRef();
maximalFusion);
}
-// TODO: Replace when this is modeled through side-effects/op traits
-static bool isMemRefDereferencingOp(Operation &op) {
- return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
- AffineDmaWaitOp>(op);
-}
-
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
return true;
// Return true if any use of 'memref' escapes the function.
for (auto *user : memref.getUsers())
- if (!isMemRefDereferencingOp(*user))
+ if (!isa<AffineMapAccessInterface>(*user))
return true;
}
return false;
// Check if 'memref' escapes through a non-affine op (e.g., std load/store,
// call op, etc.).
for (Operation *user : memref.getUsers())
- if (!isMemRefDereferencingOp(*user))
+ if (!isa<AffineMapAccessInterface>(*user))
escapingMemRefs.insert(memref);
}
}
// Interrupt the walk if found.
auto walkResult = op->walk([&](Operation *user) {
// Skip affine ops.
- if (isMemRefDereferencingOp(*user))
+ if (isa<AffineMapAccessInterface>(*user))
return WalkResult::advance();
// Find a non-affine op that uses the memref.
if (llvm::is_contained(users, user))
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
-/// Return true if this operation dereferences one or more memref's.
-// Temporary utility: will be replaced when this is modeled through
-// side-effects/op traits. TODO
-static bool isMemRefDereferencingOp(Operation &op) {
- return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
- AffineDmaWaitOp>(op);
-}
-
-/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
-static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
- return TypeSwitch<Operation *, NamedAttribute>(op)
- .Case<AffineDmaStartOp, AffineReadOpInterface, AffinePrefetchOp,
- AffineWriteOpInterface, AffineDmaWaitOp>(
- [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
-}
-
// Perform the replacement in `op`.
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
Operation *op,
OpBuilder builder(op);
// The following checks if op is dereferencing memref and performs the access
// index rewrites.
- if (!isMemRefDereferencingOp(*op)) {
- if (!allowNonDereferencingOps)
+ auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+ if (!affMapAccInterface) {
+ if (!allowNonDereferencingOps) {
// Failure: memref used in a non-dereferencing context (potentially
// escapes); no replacement in these cases unless allowNonDereferencingOps
// is set.
return failure();
+ }
op->setOperand(memRefOperandPos, newMemRef);
return success();
}
// Perform index rewrites for the dereferencing op and then replace the op
- NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+ NamedAttribute oldMapAttrPair =
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value, 4> oldMapOperands(
// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
- if (!isMemRefDereferencingOp(*op)) {
+ if (!isa<AffineMapAccessInterface>(*op)) {
if (!allowNonDereferencingOps)
return failure();
// Currently we support the following non-dereferencing ops to be a