// TODO: Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of OpInterfaces
// where a Linalg "named" op **isa** LinalgOp.
-LogicalResult ConvOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult CopyOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult FillOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return verifyGenericOp<NamedStructuredOpType>(op);
}
+struct EraseDeadLinalgOp : public RewritePattern {
+ EraseDeadLinalgOp(PatternBenefit benefit = 1)
+ : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ if (!linalgOp)
+ return failure();
+ for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+ // Linalg "inputs" may be either tensor or memref type.
+ // tensor<0xelt_type> is a convention that may not always mean
+ // "0 iterations". Only erase in cases we see memref<...x0x...>.
+ auto mt = v.getType().dyn_cast<MemRefType>();
+ if (!mt)
+ continue;
+ if (llvm::is_contained(mt.getShape(), 0)) {
+ rewriter.eraseOp(linalgOp);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
+#define CANONICALIZERS_AND_FOLDERS(XXX) \
+ void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
+ MLIRContext *context) { \
+ results.insert<EraseDeadLinalgOp>(); \
+ } \
+ \
+ LogicalResult XXX::fold(ArrayRef<Attribute>, \
+ SmallVectorImpl<OpFoldResult> &) { \
+ return foldMemRefCast(*this); \
+ }
+
+CANONICALIZERS_AND_FOLDERS(ConvOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMaxOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMinOp);
+CANONICALIZERS_AND_FOLDERS(PoolingSumOp);
+CANONICALIZERS_AND_FOLDERS(CopyOp);
+CANONICALIZERS_AND_FOLDERS(FillOp);
+CANONICALIZERS_AND_FOLDERS(GenericOp);
+CANONICALIZERS_AND_FOLDERS(IndexedGenericOp);
+
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
// TODO: Determine whether we can generate the folders and verifiers.
-LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult DotOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
-LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
+CANONICALIZERS_AND_FOLDERS(BatchMatmulOp);
+CANONICALIZERS_AND_FOLDERS(DotOp);
+CANONICALIZERS_AND_FOLDERS(MatmulOp);
+CANONICALIZERS_AND_FOLDERS(MatvecOp);
+CANONICALIZERS_AND_FOLDERS(ConvWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCWOp);
+CANONICALIZERS_AND_FOLDERS(ConvHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvDHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp);