// NOTE(review): this span is corrupted unified-diff residue, not compilable
// C++: '+'/'-' patch column markers survive on many lines, the statements
// defining 'yieldOp' (before the check below) and the loop header that
// produced 'resultType'/'returnedArg' are missing, and two lines of MLIR test
// IR are spliced into the function body near the end. Recover the original
// hunk from version control before editing further.
//
// Intent, as far as the visible code shows: a rewrite pattern on
// linalg GenericOp that erases generics acting as identity copies.
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// The '-' lines below are the diff removing an early tensor-only bailout,
// so the pattern can also consider ops with buffer semantics.
- if (!genericOp.hasTensorSemantics())
- return failure();
// Check all indexing maps are identity.
if (llvm::any_of(genericOp.getIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
// NOTE(review): the statements between the any_of check and this test are
// missing from the chunk; 'yieldOp' is used here but never defined in view
// (presumably the body's terminator dyn_cast to linalg::YieldOp — confirm).
if (!yieldOp)
return failure();
+ // In the buffer case, we need to check exact buffer equality.
+ if (genericOp.hasBufferSemantics()) {
+ // Erase only the exact self-copy: one input, one output, and both
+ // operands are literally the same buffer Value.
+ if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
+ genericOp.getInputOperand(0)->get() ==
+ genericOp.getOutputOperand(0)->get()) {
+ rewriter.eraseOp(genericOp);
+ return success();
+ }
+ return failure();
+ }
+
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
// NOTE(review): the loop header that populated 'resultType'/'returnedArg'
// is missing from the chunk; only its trailing lines survive below.
resultType, returnedArg);
returnedArgs.push_back(returnedArg);
}
+
// Replace the generic only if every result was resolved to an argument.
if (returnedArgs.size() != genericOp->getNumResults())
return failure();
rewriter.replaceOp(genericOp, returnedArgs);
// NOTE(review): the two lines below are MLIR test IR spliced into the C++
// function body by the corruption — they do not belong in this file.
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
return %r2 : index
}
+
+// -----
+
+// NOTE(review): FileCheck test added by this diff. A linalg.generic whose
+// indexing maps are both the identity, whose ins and outs are the very same
+// memref (%0), and whose body merely yields its input is a self-copy; the
+// canonicalizer is expected to erase it entirely, leaving only 'return'.
+// CHECK: func @fold_self_copy
+func @fold_self_copy(%0 : memref<4x16xf32>) {
+// CHECK-NEXT: return
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<4x16xf32>)
+ outs(%0 : memref<4x16xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ linalg.yield %arg4 : f32
+ }
+ return
+}
// NOTE(review): fragment only — the enclosing 'func' header and the opening
// of this linalg.generic are outside the visible chunk; recover the full
// test from version control before editing.
ins(%arg0 : memref<?xf32>)
outs(%arg0 : memref<?xf32>) {
^bb(%0 : f32, %1 : f32) :
// The diff replaces the identity yield with an addf, so even though ins and
// outs alias the same buffer this generic is no longer a pure self-copy and
// must NOT be folded away by the canonicalization above.
- linalg.yield %0 : f32
+ %2 = arith.addf %0, %0: f32
+ linalg.yield %2 : f32
}
return
}