[tosa] Lower tosa.slice to tensor.slice for dynamic case
authorRob Suderman <suderman@google.com>
Wed, 15 Jun 2022 16:54:23 +0000 (09:54 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Wed, 15 Jun 2022 16:54:36 +0000 (09:54 -0700)
Existing slice lowering only supporting static shapes.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D127704

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index c02108e..c8c326d 100644 (file)
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,14 +28,32 @@ public:
 
   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                 PatternRewriter &rewriter) const final {
+    Location loc = sliceOp.getLoc();
     Value input = sliceOp.input();
     SmallVector<int64_t> strides;
+    auto starts = sliceOp.start();
+    auto sizes = sliceOp.size();
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
 
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
-        ValueRange({}), sliceOp.start(), sliceOp.size(),
-        rewriter.getI64ArrayAttr(strides));
+    SmallVector<Value> dynSizes;
+    for (auto i : llvm::enumerate(sizes)) {
+      int64_t size = i.value().cast<IntegerAttr>().getInt();
+      size_t index = i.index();
+      if (size != ShapedType::kDynamicSize)
+        continue;
+
+      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
+      auto offset = rewriter.create<arith::ConstantOp>(
+          loc,
+          rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+    }
+
+    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
+        ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides));
+
+    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
     return success();
   }
 };
index 6fe862b..08d5c7d 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "../PassDetail.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,6 +32,7 @@ public:
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::SliceOp>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
     mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
index 12eb51f..15a4bcd 100644 (file)
@@ -6,3 +6,16 @@ func.func @slice(%arg0: tensor<6xf32>) ->() {
   %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>)  -> (tensor<1xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABLE: func @slice_dyn
+func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
+  // CHECK: %2 = tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>)  -> (tensor<?xf32>)
+  return %0 : tensor<?xf32>
+}
index 23308f9..e918670 100644 (file)
@@ -7897,6 +7897,7 @@ cc_library(
         "lib/Conversion/TosaToTensor",
     ],
     deps = [
+        ":ArithmeticDialect",
         ":ConversionPassIncGen",
         ":FuncDialect",
         ":IR",