From: Hanhan Wang
Date: Fri, 4 Dec 2020 07:10:20 +0000 (-0800)
Subject: [mlir][Linalg] Handle fusion on tensors for projected permutation.
X-Git-Tag: llvmorg-13-init~4398
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f5f1a5c2448e31f3c7e6f85b378372a02f8d3e43;p=platform%2Fupstream%2Fllvm.git

[mlir][Linalg] Handle fusion on tensors for projected permutation.

Previously, the reshape op could be folded only if the indexing map of the
fused tensor in the consumer was a permutation. This patch relaxes that
condition to a projected permutation.

The scalar case is still not fused. It is a corner case, because we would
need to decide where to put the extra dims.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D92466
---

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index fb916d3..3df609f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -118,11 +118,12 @@ Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
 /// dimension is statically known, or -1 otherwise.
 SmallVector<int64_t, 4> getStaticShape(LinalgOp linalgOp);
 
-/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
-/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
-/// Returns None if inverting the concatenated indexing map fails. Returns -1
+/// Returns the statically-known loop ranges of the `linalgOp`. Composes
+/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
+/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
 /// for non-statically-known loop ranges.
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index fea80fa..22e03c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -411,21 +411,19 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
   // - The linalgOp is a generic op, or an indexed_generic.
-  // - All the indexing maps for operands in linalgOp are projected
+  // - All the indexing maps for operands and results in linalgOp are projected
   //   permutations.
-  // - The indexing map at the position representing the fused tensor is a
-  //   permutation.
+  // - The fused tensor is not a scalar.
   // - All the loops in linalgOp are parallel loops.
   return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
          linalgOp.hasTensorSemantics() &&
-         llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
-                          linalgOp.getNumInputs()),
+         llvm::all_of(linalgOp.indexing_maps().getValue(),
                       [](Attribute attr) {
                         return attr.cast<AffineMapAttr>()
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() &&
+         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
          llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
            return attr.cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName();
@@ -446,8 +444,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  RankedTensorType foldedType =
-      isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType();
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
   // The reshape is folding/expanding consecutive dimensions. Given the indexing
@@ -455,9 +451,15 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   // the original op is expanded into. Also record the shape of the expanded
   // dimensions.
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  SmallVector<int64_t, 4> numFoldedDims(foldedType.getRank(), 0);
+  Optional<SmallVector<int64_t, 4>> origOpLoopRange =
+      getStaticLoopRanges(linalgOp);
+  if (!origOpLoopRange) {
+    linalgOp.emitError("unable to find loop range for operation");
+    return llvm::None;
+  }
+  SmallVector<int64_t, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
   SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      foldedType.getRank());
+      fusedIndexMap.getNumDims());
   auto reassociationMaps = reshapeOp.getReassociationMaps();
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -467,6 +469,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
         expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
     expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
+  // The remaining dimensions remain the same.
+  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+    if (expandedDimsShape[i].empty())
+      expandedDimsShape[i] = {(*origOpLoopRange)[i]};
 
   if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
     // For indexed generic op, the region contains arguments that represent the
@@ -476,6 +482,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
     // front) are statically know. For dynamic case, we would need shape
     // information on these dimensions to get these.
     for (auto &expandedShape : expandedDimsShape) {
+      if (expandedShape.size() == 1)
+        continue;
       for (int64_t expandedDimShape : llvm::make_range(
                std::next(expandedShape.begin()), expandedShape.end())) {
         if (ShapedType::isDynamic(expandedDimShape)) {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 43f4016..8e60312 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -104,13 +104,18 @@ SmallVector<int64_t, 4> getStaticShape(LinalgOp linalgOp) {
     auto shape = v.getType().cast<ShapedType>().getShape();
     res.append(shape.begin(), shape.end());
   }
+  if (linalgOp.getNumInitTensors())
+    return res;
+  for (Value v : linalgOp.getOperation()->getResults()) {
+    auto shape = v.getType().cast<ShapedType>().getShape();
+    res.append(shape.begin(), shape.end());
+  }
   return res;
 }
 
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
   SmallVector<int64_t, 4> viewSizes = getStaticShape(linalgOp);
-  AffineMap invertedMap =
-      inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
+  AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
   if (!invertedMap)
     return {};
   return invertedMap.compose(viewSizes);
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 1f201f7..66e07cc 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -344,3 +344,97 @@ func @reshape_as_consumer_permutation
 // CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
 // CHECK: %[[T10:.+]] = index_cast %[[ARG7]]
 // CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+
+// -----
+
+func @reshape_as_producer_projected_permutation
+  (%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                    affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<33x8x?xi32> into tensor<264x?xi32>
+  %1 = linalg.indexed_generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg1 : index to i32
+    %3 = addi %arg4, %2 : i32
+    %4 = index_cast %arg2 : index to i32
+    %5 = addi %3, %4 : i32
+    %6 = index_cast %arg3 : index to i32
+    %7 = addi %5, %6 : i32
+    linalg.yield %7 : i32
+  } -> tensor<264x?x4xi32>
+  return %1 : tensor<264x?x4xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: @reshape_as_producer_projected_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
+// CHECK: %[[RES:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]])
+// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
+// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
+// CHECK: %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
+// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+// CHECK: %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
+                                                   %arg1 : tensor<?x?xf32>) ->
+                                                   tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
+    %1 = mulf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK: func @generic_op_reshape_consumer_fusion_projected
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK: return %[[T2]] : tensor<?x?x4x5xf32>
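
A minimal C++ sketch (not part of the patch) of the distinction the relaxed check
relies on: `AffineMap::isPermutation()` requires every loop dimension to appear
exactly once in the map results, while `AffineMap::isProjectedPermutation()` also
accepts maps that drop dimensions, such as (d0, d1, d2) -> (d0, d1). The
standalone `main()` wrapper below is illustrative only and assumes an MLIR build
from around the time of this commit.

// Illustrative sketch; not part of this patch.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;

  // (d0, d1, d2) -> (d0, d1): drops d2, so it is a projected permutation but
  // not a permutation. Before this patch the fused tensor's indexing map had
  // to be a permutation; a map like this one is now accepted.
  AffineMap projected = AffineMap::get(
      /*dimCount=*/3, /*symbolCount=*/0,
      {getAffineDimExpr(0, &ctx), getAffineDimExpr(1, &ctx)}, &ctx);

  // (d0, d1, d2) -> (d2, d0, d1): uses every dimension exactly once, so it is
  // both a permutation and a projected permutation.
  AffineMap permutation = AffineMap::get(
      3, 0,
      {getAffineDimExpr(2, &ctx), getAffineDimExpr(0, &ctx),
       getAffineDimExpr(1, &ctx)},
      &ctx);

  llvm::outs() << projected.isProjectedPermutation() << " "   // prints 1
               << projected.isPermutation() << "\n";          // prints 0
  llvm::outs() << permutation.isProjectedPermutation() << " " // prints 1
               << permutation.isPermutation() << "\n";        // prints 1
  return 0;
}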