[mlir][tosa] Add broadcasting case for tosa.resize to linalg implementation
authorRob Suderman <suderman@google.com>
Tue, 3 Jan 2023 22:14:43 +0000 (14:14 -0800)
committerRob Suderman <suderman@google.com>
Tue, 3 Jan 2023 22:29:06 +0000 (14:29 -0800)
When lowering tosa.resize it is possible there is an unary input dimension.
Lowering to a new tosa.resize and explicit broadcast simplifies the
tosa.resize operation to avoid recomputing the identical broadcasted values.

This change reworks the broadcast optimization reuse the tosa.resize generic
implementation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D139963

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

index c4d8b68..199d81d 100644 (file)
@@ -1329,10 +1329,10 @@ public:
   }
 };
 
-// Handle the case where the resize operation is a regular broadcast. We
-// perform this part separately to avoid generating Extract operations which
-// are difficult to vectorize / optimize.
-class BroadcastResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
+// Handle the resize case where the input is a 1x1 image. This case
+// can entirely avoiding having extract operations which target much
+// more difficult to optimize away.
+class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
 public:
   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
 
@@ -1343,63 +1343,61 @@ public:
     auto input = op.getInput();
     auto inputTy = input.getType().cast<RankedTensorType>();
     auto resultTy = op.getType().cast<RankedTensorType>();
+    const bool isBilinear = op.getMode() == "BILINEAR";
 
-    auto imageH = inputTy.getDimSize(1);
-    auto imageW = inputTy.getDimSize(2);
+    auto inputH = inputTy.getDimSize(1);
+    auto inputW = inputTy.getDimSize(2);
+    auto outputH = resultTy.getDimSize(1);
+    auto outputW = resultTy.getDimSize(2);
 
-    if (imageH != 1 || imageW != 1) {
+    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
       return rewriter.notifyMatchFailure(
-          op, "tosa.resize is not a pure broadcast operation");
-    }
+          op, "tosa.resize is not a pure 1x1->1x1 image operation");
 
     // TODO(suderman): These string values should be declared the TOSA dialect.
     if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
-    const bool isBilinear = op.getMode() == "BILINEAR";
+    if (inputTy == resultTy) {
+      rewriter.replaceOp(op, input);
+      return success();
+    }
 
     SmallVector<int32_t> scale;
     getValuesFromIntArrayAttribute(op.getScale(), scale);
 
-    // Collapse the 1 dimensions away.
-    SmallVector<ReassociationExprs, 4> collapseMap(2);
-    collapseMap[0].push_back(builder.getAffineDimExpr(0));
-    collapseMap[1].push_back(builder.getAffineDimExpr(1));
-    collapseMap[1].push_back(builder.getAffineDimExpr(2));
-    collapseMap[1].push_back(builder.getAffineDimExpr(3));
+    // Collapse the unit width and height away.
+    SmallVector<ReassociationExprs, 4> reassociationMap(2);
+    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
+    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
+    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
+    reassociationMap[1].push_back(builder.getAffineDimExpr(3));
 
     auto collapseTy =
         RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                               inputTy.getElementType());
-    Value collapse =
-        builder.create<tensor::CollapseShapeOp>(collapseTy, input, collapseMap);
+    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
+                                                             reassociationMap);
 
-    // Broadcast input to the output shape.
+    // Get any dynamic shapes that appear in the input format.
     llvm::SmallVector<Value> outputDynSize;
     if (inputTy.isDynamicDim(0))
       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
-
     if (inputTy.isDynamicDim(3))
       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
 
-    llvm::SmallVector<AffineExpr> inputExprs{
-        rewriter.getAffineDimExpr(0),
-        rewriter.getAffineDimExpr(3),
-    };
-
-    auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
-                                   inputExprs, builder.getContext());
-    auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-    SmallVector<utils::IteratorType> iterators(4,
-                                               utils::IteratorType::parallel);
-
+    // Generate the elementwise operation for casting scaling the input value.
+    auto genericTy = collapseTy.clone(resultTy.getElementType());
     Value empty = builder.create<tensor::EmptyOp>(
-        resultTy.getShape(), resultTy.getElementType(), outputDynSize);
+        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
+    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
+    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
+                                               utils::IteratorType::parallel);
 
     auto generic = builder.create<linalg::GenericOp>(
-        resultTy, ValueRange{collapse}, ValueRange{empty},
-        ArrayRef<AffineMap>{inputMap, resultMap}, iterators,
+        genericTy, ValueRange{collapse}, ValueRange{empty},
+        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
         [=](OpBuilder &b, Location loc, ValueRange args) {
           Value value = args[0];
           // This is the quantized case.
@@ -1423,7 +1421,107 @@ public:
           b.create<linalg::YieldOp>(loc, value);
         });
 
-    rewriter.replaceOp(op, generic.getResult(0));
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        op, resultTy, generic.getResults()[0], reassociationMap);
+    return success();
+  }
+};
+
+// TOSA resize with width or height of 1 may be broadcasted to a wider
+// dimension. This is done by materializing a new tosa.resize without
+// the broadcasting behavior, and an explicit broadcast afterwards.
+class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
+public:
+  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ResizeOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder builder(loc, rewriter);
+    auto input = op.getInput();
+    auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+
+    if (!inputTy || !resultTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "requires ranked input/output types");
+
+    auto batch = inputTy.getDimSize(0);
+    auto channels = inputTy.getDimSize(3);
+    auto inputH = inputTy.getDimSize(1);
+    auto inputW = inputTy.getDimSize(2);
+    auto outputH = resultTy.getDimSize(1);
+    auto outputW = resultTy.getDimSize(2);
+
+    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.resize has no broadcasting behavior");
+
+    // For any dimension that is broadcastable we generate a width of 1
+    // on the output.
+    llvm::SmallVector<int64_t> resizeShape;
+    resizeShape.push_back(batch);
+    resizeShape.push_back(inputH == 1 ? 1 : outputH);
+    resizeShape.push_back(inputW == 1 ? 1 : outputW);
+    resizeShape.push_back(channels);
+
+    auto resizeTy = resultTy.clone(resizeShape);
+    auto resize =
+        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
+
+    // Collapse an unit result dims.
+    SmallVector<ReassociationExprs, 4> reassociationMap(2);
+    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
+    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
+    if (inputH != 1)
+      reassociationMap.push_back({});
+    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
+    if (inputW != 1)
+      reassociationMap.push_back({});
+    reassociationMap.back().push_back(builder.getAffineDimExpr(3));
+
+    llvm::SmallVector<int64_t> collapseShape{batch};
+    if (inputH != 1)
+      collapseShape.push_back(outputH);
+    if (inputW != 1)
+      collapseShape.push_back(outputW);
+    collapseShape.push_back(channels);
+
+    auto collapseTy = resultTy.clone(collapseShape);
+    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
+                                                             reassociationMap);
+
+    // Broadcast the collapsed shape to the output result.
+    llvm::SmallVector<Value> outputDynSize;
+    if (inputTy.isDynamicDim(0))
+      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+    if (inputTy.isDynamicDim(3))
+      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+
+    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
+                                               utils::IteratorType::parallel);
+    Value empty = builder.create<tensor::EmptyOp>(
+        resultTy.getShape(), resultTy.getElementType(), outputDynSize);
+
+    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
+    if (inputH != 1)
+      inputExprs.push_back(rewriter.getAffineDimExpr(1));
+    if (inputW != 1)
+      inputExprs.push_back(rewriter.getAffineDimExpr(2));
+    inputExprs.push_back(rewriter.getAffineDimExpr(3));
+
+    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
+                                   inputExprs, rewriter.getContext());
+
+    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, ValueRange{collapse}, ValueRange{empty},
+        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
+        [=](OpBuilder &b, Location loc, ValueRange args) {
+          Value value = args[0];
+          b.create<linalg::YieldOp>(loc, value);
+        });
+
     return success();
   }
 };
@@ -2226,8 +2324,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
   // We have multiple resize coverters to handle degenerate cases.
   patterns->add<GenericResizeConverter>(patterns->getContext(),
                                         /*benefit=*/100);
-  patterns->add<BroadcastResizeConverter>(patterns->getContext(),
-                                          /*benefit=*/200);
+  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
+                                      /*benefit=*/200);
+  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
+                                            /*benefit=*/300);
 
   patterns->add<
       // clang-format off
index 382dae5..9a6067f 100644 (file)
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: @broadcast_resize_nearest_fp
-func.func @broadcast_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32> {
-  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
-  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
-  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xf32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xf32>)
-  // CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-  // CHECK:   linalg.yield %[[IN]]
-  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32>
-
-  // CHECK: return %[[GENERIC]]
-  return %resize : tensor<3x15x13x7xf32>
+// CHECK-LABEL: @unary_resize_nearest_fp
+func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xf32>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: @broadcast_resize_bilinear_fp
-func.func @broadcast_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32> {
-  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
-  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
-  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xf32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xf32>)
-  // CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-  // CHECK:   linalg.yield %[[IN]]
-  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32>
-
-  // CHECK: return %[[GENERIC]]
-  return %resize : tensor<3x15x13x7xf32>
+// CHECK-LABEL: @unary_resize_bilinear_fp
+func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @unary_resize_nearest_i8
+func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> {
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xi8>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: @broadcast_resize_nearest_i8
-func.func @broadcast_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi8> {
+// CHECK-LABEL: @broadcast_resize_nearest_f32
+func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> {
   // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
-  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
-  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi8>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi8>)
-  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i8):
-  // CHECK:   linalg.yield %[[IN]]
-  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi8>
-
-  // CHECK: return %[[GENERIC]]
-  return %resize : tensor<3x15x13x7xi8>
+  // CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xf32> into tensor<3x7xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xf32>) outs(%[[EMPTY]] : tensor<3x1x5x7xf32>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:   linalg.yield %[[IN]] : f32
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32>
+
+ // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xf32>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: @broadcast_resize_nearest_i32
-func.func @broadcast_resize_nearest_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32> {
+// CHECK-LABEL: @broadcast_resize_bilinear_i8
+func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x4x5x7xi32> {
   // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
-  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
-  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi32>)
-  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xi8> into tensor<3x7xi8>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x7xi32>
+  // CHECK: %[[RESIZE:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>) outs(%[[EMPTY]] : tensor<3x7xi32>)
+  // CHECK: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
   // CHECK:   %[[EXT:.+]] = arith.extsi %[[IN]] : i8 to i32
-  // CHECK:   linalg.yield %[[EXT]]
-  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32>
-
-  // CHECK: return %[[GENERIC]]
-  return %resize : tensor<3x15x13x7xi32>
+  // CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK:   %[[MUL:.+]] = arith.muli %[[EXT]], %[[C2]] : i32
+  // CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : i32
+  // CHECK:   %[[OUT:.+]] = arith.muli %[[MUL]], %[[C3]] : i32
+  // CHECK:   linalg.yield %[[OUT]] : i32
+  // CHECK: } -> tensor<3x7xi32>
+  // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %1
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]] : tensor<3x7xi32> into tensor<3x1x1x7xi32>
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %expanded
+  // CHECK-SAME{literal}:[[0], [1, 2, 3]] : tensor<3x1x1x7xi32> into tensor<3x7xi32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x4x5x7xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi32>) outs(%[[EMPTY]] : tensor<3x4x5x7xi32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+  // CHECK:   linalg.yield %[[IN]] : i32
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x4x5x7xi32>
+
+  // CHECK: return %[[BROADCAST]]
+  return %resize : tensor<3x4x5x7xi32>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: @broadcast_resize_bilinear_i32
-func.func @broadcast_resize_bilinear_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32> {
+// CHECK-LABEL: @unary_resize_bilinear_i32
+func.func @unary_resize_bilinear_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi32> {
   // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
-  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
-  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi32>)
-  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
-  // CHECK: %[[EXT:.+]] = arith.extsi %[[IN]] : i8 to i32
-  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
-  // CHECK: %[[MUL1:.+]] = arith.muli %[[EXT]], %[[C2]] : i32
-  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
-  // CHECK: %[[MUL2:.+]] = arith.muli %[[MUL1]], %[[C1]] : i32
-  // CHECK: linalg.yield %[[MUL2]]
-  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32>
-
-  // CHECK: return %[[GENERIC]]
-  return %resize : tensor<3x15x13x7xi32>
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xi8> into tensor<3x7xi8>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x7xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map, #map]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>) outs(%[[EMPTY]] : tensor<3x7xi32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
+  // CHECK:   %[[EXT:.+]] = arith.extsi %[[IN]] : i8 to i32
+  // CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK:   %[[MUL0:.+]] = arith.muli %[[EXT]], %[[C2]] : i32
+  // CHECK-DAG:   %[[C1:.+]] = arith.constant 2 : i32
+  // CHECK:   %4 = arith.muli %3, %[[C1]] : i32
+  // CHECK:   linalg.yield %4 : i32
+  // CHECK: } -> tensor<3x7xi32>
+  // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[GENERIC:.+]]
+  // CHECK-SAME{literal} [[0], [1, 2, 3]] : tensor<3x7xi32> into tensor<3x1x1x7xi32>
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 2, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi32>
+
+  // CHECK: return %[[EXPAND]]
+  return %resize : tensor<3x1x1x7xi32>
 }
 
 // -----