[tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.
authornatashaknk <natashaknk@google.com>
Mon, 15 Nov 2021 23:10:36 +0000 (15:10 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Mon, 15 Nov 2021 23:31:37 +0000 (15:31 -0800)
Split tosa.reshape into three individual lowerings: collapse, expand and a
combination of both. Add simple dynamic shape support.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D113936

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index e90d153..f4470d2 100644 (file)
@@ -946,6 +946,112 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   return success();
 }
 
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {-1};
+    return true;
+  }
+
+  if (lhsShape.empty() || rhsShape.empty()) {
+    intermediateShape = {};
+    return true;
+  }
+
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        lhsSize *= lhsShape[currLhsDim];
+      } else {
+        currRhsDim++;
+        rhsSize *= rhsShape[currRhsDim];
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // If the iterators didn't reach the end and their leftover dimensions are not
+  // equal to 1 an intermediate shape was not found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If both iterators didn't reach the end, we have leftover dimentions which
+  // implies that we have a mismatch in shape.
+  if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) {
+    return false;
+  }
+
+  return true;
+}
+
 namespace {
 
 template <typename SrcOp>
@@ -1534,7 +1640,7 @@ public:
   }
 };
 
-class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
@@ -1543,103 +1649,116 @@ public:
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && resultTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot collapse dynamic dims to more than one dimension");
+    }
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                            resultTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to collapse into an incompatible shape");
+    }
 
-    // Compute the reassociation maps for the linalg operation.
-    ArrayRef<int64_t> expandedShape =
-        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
-                                                  : resultTy.getShape());
-    ArrayRef<int64_t> collapsedShape =
-        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
-                                                  : operandTy.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
-
-    // First scan all dimensions in the source shapes to see whether we have a
-    // perfect case where consecutive dimensions in source are collapsed. For
-    // such case we can just generate one single linalg.reshape.
-    bool isCollapsingSource = true;
-    while (currSrcDim < expandedShape.size() &&
-           currDstDim < collapsedShape.size()) {
-      int64_t dstSize = collapsedShape[currDstDim];
-      int64_t srcSize = expandedShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= expandedShape[currSrcDim];
-      }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in collapsedShape is not 1, treat subsequent dims in
-        // expandedShape which are 1 to be collapsed.
-        if (currDstDim == collapsedShape.size() - 1 ||
-            collapsedShape[currDstDim + 1] != 1) {
-          while (currSrcDim < expandedShape.size() &&
-                 expandedShape[currSrcDim] == 1) {
-            reassociationMap[currDstDim].push_back(
-                rewriter.getAffineDimExpr(currSrcDim++));
-          }
-        }
-      } else {
-        isCollapsingSource = false;
-        break;
-      }
-      currDstDim++;
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot collapse into given shape");
     }
 
-    // Check if any remaining dimensions exist. If either is rank-0 we only
-    // require the directly lowering.
-    if (currSrcDim != expandedShape.size() ||
-        currDstDim != collapsedShape.size())
-      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
-
-    // Otherwise, we need to first reduce all source dimensions into one and
-    // then expand to the destination dimensions.
-    if (!isCollapsingSource) {
-      auto getIdentityExprs = [&rewriter](int n) {
-        SmallVector<AffineExpr, 4> exprs;
-        for (int i = 0; i < n; ++i)
-          exprs.push_back(rewriter.getAffineDimExpr(i));
-        return exprs;
-      };
-      Location loc = reshape.getLoc();
-      int64_t totalElems =
-          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
-                          std::multiplies<int64_t>());
-      auto elemTy = operandTy.getElementType();
-      SmallVector<ReassociationExprs, 4> collapsingMap = {
-          // Use operandTy here because we need to collapse all operands
-          // dimensions.
-          getIdentityExprs(operandTy.getShape().size())};
-      SmallVector<ReassociationExprs, 4> expandingMap = {
-          // Use resultTy here because we need to expand to all result
-          // dimensions.
-          getIdentityExprs(resultTy.getShape().size())};
-
-      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
-      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
-          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, collapsedOp, expandingMap);
+    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
 
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (resultTy.getRank() <
-        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
-      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    else
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    if (isDynamic && operandTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot expand dynamic dims from more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                            operandTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to expand into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic) ||
+        intermediateShape != operandTy.getShape()) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot expand into given shape");
+    }
+    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
+      return success();
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
+                   "the given two shapes");
+    }
+
+    Value collapse = rewriter.create<tosa::ReshapeOp>(
+        reshape.getLoc(),
+        RankedTensorType::get(intermediateShape,
+                              reshape.getType().getElementType()),
+        adaptor.input1());
+    Value expand =
+        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+    rewriter.replaceOp(reshape, expand);
 
     return success();
   }
@@ -3072,7 +3191,9 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TransposeConvConverter,
       GatherConverter,
       PadConverter,
-      ReshapeConverter,
+      ReshapeConverterCollapse,
+      ReshapeConverterExpand,
+      ReshapeConverterCollapseExpand,
       RescaleConverter,
       ResizeConverter,
       ReverseConverter,
index d072808..2e25ad9 100644 (file)
@@ -541,6 +541,16 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_downrank_dyn
+func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor<?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_uprank
 func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
   // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
@@ -551,6 +561,16 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_uprank_dyn
+func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?xf32>) -> tensor<2x?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_samerank
 func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
   // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
@@ -563,6 +583,18 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_samerank_dyn
+func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
+  // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+  // CHECK-NEXT: return %[[RESHAPE2]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_downrank_6D
 func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
   // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
@@ -572,6 +604,16 @@ func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77
 
 // -----
 
+// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
+  // CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+  return %0 : tensor<?x5x77xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
   %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>