}
};
+/// Fold dim(x), where `x` is a region input/output argument of a TiledLoopOp,
+/// to dim(y), where `y` is the value bound to that argument in the loop's
+/// `ins`/`outs` list.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+/// tensor.dim %x, %c0 : tensor<...>
+/// }
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+/// tensor.dim %y, %c0 : tensor<...>
+/// }
+///
+/// The pattern is templated on the dim op type so that it can be instantiated
+/// for any op exposing `source()` / `sourceMutable()`.
+///
+/// NOTE(review): for region *output* arguments this presumably relies on the
+/// output keeping the shape of its initial value across iterations -- confirm
+/// that TiledLoopOp guarantees this.
+template <typename OpTy>
+struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  /// Returns success() after rebinding the source operand of `dimOp` in place
+  /// when that source is an input/output block argument of a TiledLoopOp;
+  /// returns failure() (leaving the IR untouched) otherwise.
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto src = dimOp.source().template dyn_cast<BlockArgument>();
+    if (!src)
+      return failure();
+
+    // Block::getParentOp() yields null for a block that is not attached to a
+    // region, so go through it (with dyn_cast_or_null) instead of chaining
+    // getParent()->getParentOp(), which would dereference a null region.
+    auto loopOp =
+        dyn_cast_or_null<TiledLoopOp>(src.getOwner()->getParentOp());
+    if (!loopOp)
+      return failure();
+
+    // Find the value that was bound to `src` when entering the loop: look
+    // among the input arguments first, then among the output arguments.
+    Value boundValue;
+    auto inputArgs = loopOp.getRegionInputArgs();
+    auto inputIt = llvm::find(inputArgs, src);
+    if (inputIt != inputArgs.end()) {
+      boundValue = loopOp.inputs()[inputIt - inputArgs.begin()];
+    } else {
+      auto outputArgs = loopOp.getRegionOutputArgs();
+      auto outputIt = llvm::find(outputArgs, src);
+      // Any other block argument (e.g. an induction variable) is not handled.
+      if (outputIt == outputArgs.end())
+        return failure();
+      boundValue = loopOp.outputs()[outputIt - outputArgs.begin()];
+    }
+
+    // Only the operand changes, so update the op in place and notify the
+    // rewriter rather than creating a new dim op.
+    rewriter.updateRootInPlace(
+        dimOp, [&] { dimOp.sourceMutable().assign(boundValue); });
+    return success();
+  }
+};
+
// Folds away TiledLoopOp output tensors when the following conditions are met:
// * result of `linalg.tiled_loop` has no uses
// * output tensor is the argument of `linalg.yield`
void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
+ results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
+ DimOfTiledLoopInsOutsFolder<tensor::DimOp>, // dim-of-block-argument folder instantiated for tensor::DimOp
+ DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context); // same folder instantiated for memref::DimOp
}
LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
%r = tensor.dim %0, %c0 : tensor<?x?xf32>
return %r : index
}
+
+// -----
+
+// Verifies that a tensor.dim taken on a linalg.tiled_loop region input
+// argument (%in1) is rewritten to query the value bound to that argument in
+// the `ins` list (%arg1), while the dim op itself stays inside the loop body
+// and still feeds the index_cast.
+// CHECK-LABEL: func @dim_of_tiled_loop_input(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: linalg.tiled_loop
+// CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]]
+// CHECK: index_cast %[[dim]]
+func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+ to (%d0, %d1) step (%c1, %c1)
+ ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+ outs (%out1 = %arg2 : tensor<?x?xf32>) {
+ %inner_dim = tensor.dim %in1, %c0 : tensor<?x?xf32>
+ %cast1 = std.index_cast %inner_dim : index to i32
+ %cast2 = std.sitofp %cast1 : i32 to f32
+ %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ linalg.yield %fill : tensor<?x?xf32>
+ }
+ return %r : tensor<?x?xf32>
+}
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
-// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]]
-// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]]
// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]