return success();
}
};
+
+/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
+/// with the corresponding output tensor argument of the linalg op.
+struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ Value dimOpArg = dimOp.memrefOrTensor();
+ auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
+ if (!linalgOp)
+ return failure();
+
+ auto results = linalgOp.getOperation()->getResults();
+ int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
+ auto outputTensors = linalgOp.getOutputTensors();
+ rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
+ return success();
+ }
+};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
MLIRContext *context) { \
- results.insert<EraseDeadLinalgOp>(); \
- results.insert<FoldTensorCastOp>(); \
- results.insert<DeduplicateInputs>(); \
+ results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
+ results.insert<ReplaceDimOfLinalgResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+ %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %0, %1 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel"]
+ } ins(%arg_0 : tensor<?xf32>)
+ outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+ linalg.yield %in, %in : f32, f32
+ } -> tensor<?xf32>, tensor<?xf32>
+
+ %c0 = constant 0 : index
+ %num_elem_0 = dim %0, %c0 : tensor<?xf32>
+ %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
+
+ %num_elem_1 = dim %1, %c0 : tensor<?xf32>
+ %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
+ return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+}
+// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
+// CHECK: dim [[ARG_0]]
+// CHECK: dim [[ARG_1]]