From 640973f2b99b9b9eb85be096626fd0a7fc7d1dfe Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 15 Jun 2022 09:54:23 -0700 Subject: [PATCH] [tosa] Lower tosa.slice to tensor.slice for dynamic case Existing slice lowering only supporting static shapes. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D127704 --- mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp | 27 ++++++++++++++++++---- .../Conversion/TosaToTensor/TosaToTensorPass.cpp | 2 ++ .../Conversion/TosaToTensor/tosa-to-tensor.mlir | 13 +++++++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index c02108e..c8c326d 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" @@ -27,14 +28,32 @@ public: LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const final { + Location loc = sliceOp.getLoc(); Value input = sliceOp.input(); SmallVector strides; + auto starts = sliceOp.start(); + auto sizes = sliceOp.size(); strides.resize(sliceOp.getType().template cast().getRank(), 1); - rewriter.replaceOpWithNewOp( - sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}), - ValueRange({}), sliceOp.start(), sliceOp.size(), - rewriter.getI64ArrayAttr(strides)); + SmallVector dynSizes; + for (auto i : llvm::enumerate(sizes)) { + int64_t size = i.value().cast().getInt(); + size_t index = i.index(); + if (size != ShapedType::kDynamicSize) + continue; + + auto dim = rewriter.create(loc, input, index); + auto offset = rewriter.create( + loc, + rewriter.getIndexAttr(starts[index].cast().getInt())); + dynSizes.push_back(rewriter.create(loc, dim, offset)); + } + + auto newSliceOp = rewriter.create( + sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, + ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides)); + + rewriter.replaceOp(sliceOp, newSliceOp.getResult()); return success(); } }; diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp index 6fe862b..08d5c7d 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" @@ -31,6 +32,7 @@ public: RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addIllegalOp(); + target.addLegalDialect(); target.addLegalDialect(); mlir::tosa::populateTosaToTensorConversionPatterns(&patterns); diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index 12eb51f..15a4bcd 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -6,3 +6,16 @@ func.func @slice(%arg0: tensor<6xf32>) ->() { %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>) return } + +// ----- + +// CHECK-LABLE: func @slice_dyn +func.func @slice_dyn(%arg0: tensor) -> (tensor) { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[C2:.+]] = arith.constant 2 : index + // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]] + // CHECK: %2 = tensor.extract_slice %arg0[2] [%[[SUB]]] [1] + %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor) -> (tensor) + return %0 : tensor +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 23308f9..e9186702 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7897,6 +7897,7 @@ cc_library( "lib/Conversion/TosaToTensor", ], deps = [ + ":ArithmeticDialect", ":ConversionPassIncGen", ":FuncDialect", ":IR", -- 2.7.4