From 06c02d5dbb13f6d2a10eaa75c236f3c61cdf5b91 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 18 Aug 2022 22:39:29 +0000 Subject: [PATCH] [mlir][linalg] Fix tiling interface implementation offset calculation The tiling interface implementation was making assumption on the code generated by makeTiledShape which were wrong. The ExtractSliceOp create may be combined with other ExtractSliceOp. To solve that we compute directly the offset using the new utilities. Differential Revision: https://reviews.llvm.org/D132182 --- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 37 +++++++--------- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 15 +++---- .../test/Dialect/Linalg/multisize-tiling-full.mlir | 2 +- .../Dialect/Linalg/tile-to-foreach-thread.mlir | 50 ++++++++++++++++------ 4 files changed, 60 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 417e727..e5040a4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -177,18 +177,6 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op, return spec; } -/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new -/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location -/// as `subsetExtractOp`. -static void -createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc, - tensor::ExtractSliceOp subsetExtractOp, - Value source, Value dest) { - b.create( - loc, source, dest, subsetExtractOp.getMixedOffsets(), - subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides()); -} - /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less /// than `iterationSize`. static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, @@ -333,16 +321,21 @@ static FailureOr tileToForeachThreadOpImpl( auto tilingInterfaceOp = dyn_cast(tiledOp); assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface"); - - auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b); - - // Create terminator with parallel subset insert operations. - b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody()); - for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(), - destOperands)) { - createMatchingParallelSubsetInsertOp( - b, loc, cast(std::get<0>(it).getDefiningOp()), - std::get<1>(it), std::get<2>(it)); + OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); + for (auto it : + llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())), + tilingInterfaceOp->getResults(), destOperands)) { + b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); + SmallVector resultOffsets, resultSizes; + if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets, + tiledSizes, resultOffsets, + resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector strides(resultSizes.size(), b.getIndexAttr(1)); + b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody()); + b.create(loc, std::get<1>(it), + std::get<2>(it), resultOffsets, + resultSizes, strides); } return ForeachThreadTilingResult{foreachThreadOp, tiledOp}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 67d5e99..cfd3c12 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -161,15 +161,12 @@ struct LinalgOpTilingInterface })); OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); - Value sliceOpResult = - makeTiledShape(b, loc, outOperand->get(), sizes, - linalgOp.getTiedIndexingMap(outOperand), offsets, - /*ubs*/ {}, subShapeSizes, true); - auto sliceOp = sliceOpResult.getDefiningOp(); - if (!sliceOp) - return failure(); - resultOffsets = sliceOp.getMixedOffsets(); - resultSizes = sliceOp.getMixedSizes(); + SliceParameters sliceParams = + computeSliceParameters(b, loc, outOperand->get(), sizes, + linalgOp.getTiedIndexingMap(outOperand), offsets, + /*ubs*/ {}, subShapeSizes, true); + resultOffsets = sliceParams.offsets; + resultSizes = sliceParams.sizes; return success(); } diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir index cfbac63..ecafcad 100644 --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -59,7 +59,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>, // CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]] // CHECK: scf.yield %[[RESPARTIAL]] - // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1] + // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1] // CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1] // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]]) // CHECK-COUNT-2: tensor.extract_slice diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir index ab63ac9..dd06290 100644 --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s // Offset per thread: // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))> @@ -22,7 +22,7 @@ module { // CHECK: %[[RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor, tensor) // CHECK-SAME: outs(%[[tC]] : tensor) -> tensor - // CHECK-NEXT: scf.foreach_thread.perform_concurrently { + // CHECK: scf.foreach_thread.perform_concurrently { // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} : // CHECK-SAME: tensor into tensor // CHECK-NEXT: } @@ -65,11 +65,9 @@ func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: t // CHECK-NOT: affine.max // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]]) // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]]) - // CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]]) - // CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]]) // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] : // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] : - // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] : + // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice @@ -106,8 +104,6 @@ func.func @matmul_tile_size_dynamic(%A: tensor, %B: tensor, %C // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : // CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]] // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]] - // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 : - // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]] // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]] @@ -115,8 +111,6 @@ func.func @matmul_tile_size_dynamic(%A: tensor, %B: tensor, %C // CHECK tensor.extract_slice %[[A]] // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]]) // CHECK tensor.extract_slice %[[B]] - // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]]) - // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]]) // CHECK tensor.extract_slice %[[C]] // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently @@ -156,11 +150,9 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf // CHECK-NOT: affine.min // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]]) // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]]) - // CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]]) - // CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]]) // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] : // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] : - // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] : + // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice @@ -177,3 +169,37 @@ transform.with_pdl_patterns { %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21] } } + +// ----- + +module { + func.func @extract_source(%A: tensor<4xf32>, %B: tensor<16xf32>) -> tensor<4xf32> { + %B1 = tensor.extract_slice %B[10] [4] [1] : tensor<16xf32> to tensor<4xf32> + %result = linalg.generic {indexing_maps = [ + affine_map<(d0) -> (d0)>,affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%A : tensor<4xf32>) outs(%B1 : tensor<4xf32>) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %2 = arith.addf %arg3, %arg3 : f32 + linalg.yield %2 : f32 + } -> tensor<4xf32> + return %result : tensor<4xf32> + } + + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0]) + } + } +} +// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)> + +// CHECK-LABEL: extract_source( +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) { +// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]]) +// CHECK: scf.foreach_thread.perform_concurrently { +// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32> -- 2.7.4