[mlir][tosa] Fix tosa.slice shape inference for ShapedType:kDynamicShape
authorRob Suderman <suderman@google.com>
Fri, 18 Nov 2022 19:43:55 +0000 (11:43 -0800)
committerRob Suderman <suderman@google.com>
Fri, 18 Nov 2022 19:44:05 +0000 (11:44 -0800)
Change for kDynamicShape means the size needs to be updated to a new value
for slice operation shape inference. Landing fix.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D138314

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index 2164fd7..e33d2bc 100644 (file)
@@ -590,6 +590,12 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
   return success();
 }
 
+static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return dim == -1 ? ShapedType::kDynamicSize : dim;
+  }));
+}
+
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -601,7 +607,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+    convertToMlirShape(outputShape)));
   return success();
 }
 
@@ -655,11 +662,6 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
-static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return dim == -1 ? ShapedType::kDynamicSize : dim;
-  }));
-}
 
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
index 311851a..70a07db 100644 (file)
@@ -539,6 +539,15 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_slice_dynamic
+func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
+  // CHECK: "tosa.slice"(%arg0) {size = [7, -1, 1], start = [1, 0, 0]} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
+  %0 = "tosa.slice"(%arg0) {size = [7, -1, 1], start = [1, 0, 0]} : (tensor<10x?x2xf32>) -> tensor<?x?x?xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_tile
 func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
   // CHECK: "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>