[mlir][Linalg] Handle fusion on tensors for projected permutation.
authorHanhan Wang <hanchung@google.com>
Fri, 4 Dec 2020 07:10:20 +0000 (23:10 -0800)
committerHanhan Wang <hanchung@google.com>
Fri, 4 Dec 2020 07:11:29 +0000 (23:11 -0800)
In the past, the reshape op can be folded only if the indexing map is
permutation in consumer's usage. We can relax to condition to be projected
permutation.

This patch still limits the fusion for scalar cases. Scalar case is a corner
case, because we need to decide where to put extra dims.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D92466

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir

index fb916d3..3df609f 100644 (file)
@@ -118,11 +118,12 @@ Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
 /// dimension is statically known, or -1 otherwise.
 SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
 
-/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
-/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
-/// Returns None if inverting the concatenated indexing map fails. Returns -1
+/// Returns the statically-known loop ranges of the `linalgOp`. Composes
+/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
+/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
 /// for non-statically-known loop ranges.
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
index fea80fa..22e03c1 100644 (file)
@@ -411,21 +411,19 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
   // - The linalgOp is a generic op, or an indexed_generic.
-  // - All the indexing maps for operands in linalgOp are projected
+  // - All the indexing maps for operands and results in linalgOp are projected
   //   permutations.
-  // - The indexing map at the position representing the fused tensor is a
-  //   permutation.
+  // - The fused tensor is not a scalar.
   // - All the loops in linalgOp are parallel loops.
   return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
          linalgOp.hasTensorSemantics() &&
-         llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
-                          linalgOp.getNumInputs()),
+         llvm::all_of(linalgOp.indexing_maps().getValue(),
                       [](Attribute attr) {
                         return attr.cast<AffineMapAttr>()
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() &&
+         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
          llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
            return attr.cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName();
@@ -446,8 +444,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  RankedTensorType foldedType =
-      isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType();
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
   // The reshape is folding/expanding consecutive dimensions. Given the indexing
@@ -455,9 +451,15 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   // the original op is expanded into. Also record the shape of the expanded
   // dimensions.
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
+  Optional<SmallVector<int64_t, 4>> origOpLoopRange =
+      getStaticLoopRanges(linalgOp);
+  if (!origOpLoopRange) {
+    linalgOp.emitError("unable to find loop range for operation");
+    return llvm::None;
+  }
+  SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
   SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      foldedType.getRank());
+      fusedIndexMap.getNumDims());
   auto reassociationMaps = reshapeOp.getReassociationMaps();
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -467,6 +469,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
         expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
     expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
+  // The remaining dimensions remain the same.
+  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+    if (expandedDimsShape[i].empty())
+      expandedDimsShape[i] = {(*origOpLoopRange)[i]};
 
   if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
     // For indexed generic op, the region contains arguments that represent the
@@ -476,6 +482,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
     // front) are statically know. For dynamic case, we would need shape
     // information on these dimensions to get these.
     for (auto &expandedShape : expandedDimsShape) {
+      if (expandedShape.size() == 1)
+        continue;
       for (int64_t expandedDimShape : llvm::make_range(
                std::next(expandedShape.begin()), expandedShape.end())) {
         if (ShapedType::isDynamic(expandedDimShape)) {
index 43f4016..8e60312 100644 (file)
@@ -104,13 +104,18 @@ SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
     auto shape = v.getType().cast<ShapedType>().getShape();
     res.append(shape.begin(), shape.end());
   }
+  if (linalgOp.getNumInitTensors())
+    return res;
+  for (Value v : linalgOp.getOperation()->getResults()) {
+    auto shape = v.getType().cast<ShapedType>().getShape();
+    res.append(shape.begin(), shape.end());
+  }
   return res;
 }
 
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
   SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
-  AffineMap invertedMap =
-      inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
+  AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
   if (!invertedMap)
     return {};
   return invertedMap.compose(viewSizes);
index 1f201f7..66e07cc 100644 (file)
@@ -344,3 +344,97 @@ func @reshape_as_consumer_permutation
 //       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
 //       CHECK:       %[[T10:.+]] = index_cast %[[ARG7]]
 //       CHECK:       %[[T11:.+]] = addi %[[T9]], %[[T10]]
+
+// -----
+
+func @reshape_as_producer_projected_permutation
+  (%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                    affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<33x8x?xi32> into tensor<264x?xi32>
+  %1 = linalg.indexed_generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg1 : index to i32
+    %3 = addi %arg4, %2 : i32
+    %4 = index_cast %arg2 : index to i32
+    %5 = addi %3, %4 : i32
+    %6 = index_cast %arg3 : index to i32
+    %7 = addi %5, %6 : i32
+    linalg.yield %7 : i32
+  } -> tensor<264x?x4xi32>
+  return %1 : tensor<264x?x4xi32>
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//       CHECK: @reshape_as_producer_projected_permutation
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
+//       CHECK:   %[[RES:.+]] = linalg.indexed_generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     ins(%[[ARG0]] : tensor<33x8x?xi32>)
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: i32)
+//       CHECK:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]])
+//       CHECK:       %[[T1:.+]] = index_cast %[[T0]] : index to i32
+//       CHECK:       %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
+//       CHECK:       %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
+//       CHECK:       %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+//       CHECK:       %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
+//       CHECK:       %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+//       CHECK:       linalg.yield %[[T6]] : i32
+//       CHECK:    %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+//  CHECK-SAME:      [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+//       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
+                                                   %arg1 : tensor<?x?xf32>) ->
+                                                   tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map1],
+     iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_projected
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
+//      CHECK:   return %[[T2]] : tensor<?x?x4x5xf32>