From 56796ae1a8db4c85dada28676f8303a5a3609c63 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 15 Jul 2022 13:43:57 -0400 Subject: [PATCH] [mlir][linalg] Fix tensor tiling together with interchange In `linalg::tileConsumerAndFuseProducers`, there are two levels of tiling and fusion; we partition the tile sizes and only use one half for each of them. The partition is using the first non-parallel dimension *after* interchange as the boundary. However, concrete tiling happens *together with* loop interchange, so we still need to provide the partial tile sizes *before* the interchange. Otherwise, there will be inconsistency, which is what this patch is to fix. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D129804 --- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 17 +++++--- mlir/test/Dialect/Linalg/transform-op-fuse.mlir | 51 ++++++++++++++++++++++ 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index d968b37..66a558c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -454,19 +454,24 @@ FailureOr mlir::linalg::tileConsumerAndFuseProducers( } }; + // Perform tiling and fusion in two steps. We need to respect the loop + // interchange here; filter parallel dimensions based on their order *after* + // permutation but pass in the original configuration *before* permutation, + // given the tiling and interchange happen together. + SmallVector outerTileSizes(tileSizes.size(), 0); + SmallVector innerTileSizes(tileSizes.size(), 0); + for (int64_t i : tileInterchange.take_front(split)) + outerTileSizes[i] = tileSizes[i]; + for (int64_t i : tileInterchange.drop_front(split)) + innerTileSizes[i] = tileSizes[i]; + // Tile the outer parallel loops and fuse the output operands. 
- SmallVector outerTileSizes; - outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); - outerTileSizes.append(tileSizes.size() - split, 0); if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, tileDistribution))) return failure(); fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); // Tile the remaining loops and fuse the input operands. - SmallVector innerTileSizes; - innerTileSizes.append(split, 0); - innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, tileDistribution))) return failure(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index af6da5d..1d4d620 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -68,3 +68,54 @@ transform.with_pdl_patterns { transform.loop.peel %loops#0 } } + +// ----- + +// CHECK-LABEL: func.func @interchange_reduction +// CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) +func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { + %five = arith.constant 5.0 : f32 + %init = linalg.init_tensor [12, 25] : tensor<12x25xf32> + +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index + +// CHECK: %[[INIT:.+]] = linalg.init_tensor [12, 25] +// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]]) +// CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]]) +// CHECK: %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]] +// CHECK: %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE0]] : tensor) +// CHECK: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]]) +// CHECK: %[[IN_SLICE:.+]] = 
tensor.extract_slice %[[INPUT]] +// CHECK: %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0] +// CHECK: linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor) outs(%[[OUT_SLICE2]] : tensor) + + %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%input : tensor<12x7x25xf32>) outs(%fill : tensor<12x25xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + %2 = arith.addf %arg0, %arg1 : f32 + linalg.yield %2 : f32 + } -> tensor<12x25xf32> + func.return %0 : tensor<12x25xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.generic"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]} + } +} -- 2.7.4