From 6b4b63a832f105039442fc983d0b309abe5261d5 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 30 May 2023 10:43:24 -0700 Subject: [PATCH] Lowering for 'tosa.scatter' This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering: ``` func.func @tosa( %valuesIn : tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input : tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %0 = "tosa.scatter"(%valuesIn, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>) return %0 : tensor<3x7x5xi32> } ``` translates to ``` func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %c0 = arith.constant 0 : index %c3 = arith.constant 3 : index %c1 = arith.constant 1 : index %c6 = arith.constant 6 : index %c2 = arith.constant 2 : index %c5 = arith.constant 5 : index %c0_0 = arith.constant 0 : index %c1_1 = arith.constant 1 : index %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) { %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) { %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32> %2 = arith.index_cast %extracted : i32 to index %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32> %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32> scf.yield %inserted_slice : tensor<3x7x5xi32> } scf.yield %1 : tensor<3x7x5xi32> } return %0 : tensor<3x7x5xi32> } ``` We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons: - The `tensor.scatter` op doesn't seem to be used anywhere. 
There is no available lowering pass for this op (although we have one that we'll upstream soon). - The `tosa.scatter` and `tensor.scatter` op have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified. Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D151117 --- mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 73 ++++++++++++++++++++++++- mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp | 2 +- mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir | 30 ++++++++++ 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 8f10497..9139bf1 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -82,6 +82,75 @@ public: } }; +class ScatterOpConverter : public OpRewritePattern { + static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor, + int64_t dim) { + return builder.createOrFold(loc, tensor, dim); + } + + static Value createIndexConst(OpBuilder &builder, Location loc, + int64_t value) { + return builder.create(loc, value); + } + +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ScatterOp scatter, + PatternRewriter &rewriter) const final { + auto valuesIn = scatter.getValuesIn(); + auto indices = scatter.getIndices(); + auto input = scatter.getInput(); + auto loc = scatter.getLoc(); + + // N, W, C are chosen to match the TOSA spec + auto dimN = createTensorDim(rewriter, loc, input, 0); + auto dimW = createTensorDim(rewriter, 
loc, input, 1); + auto dimC = createTensorDim(rewriter, loc, input, 2); + + auto zero = createIndexConst(rewriter, loc, 0); + auto one = createIndexConst(rewriter, loc, 1); + + // Loop bounds + auto lbs = llvm::SmallVector(2, zero); + auto steps = llvm::SmallVector(2, one); + auto ubs = llvm::SmallVector{{dimN, dimW}}; + + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + auto n = ivs[0]; + + // Read the index and cast it to index type + auto index = builder.create(loc, indices, ivs); + auto castIndex = builder.create( + loc, builder.getIndexType(), index); + + // Offset, sizes, and strides for the input tensor + auto inputOffset = llvm::to_vector(ivs); + inputOffset.push_back(zero); + + llvm::SmallVector sizes = {one, one, dimC}; + llvm::SmallVector strides = {one, one, one}; + + auto slice = builder.create( + loc, input, inputOffset, sizes, strides); + + // Insert the slice into the output accumulator tensor. + llvm::SmallVector outputOffset = {n, castIndex, zero}; + auto updated = builder.create( + loc, slice, args[0], outputOffset, sizes, strides); + + return {updated}; + }; + + auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, + ValueRange{valuesIn}, buildBody); + rewriter.replaceOp(scatter, loops.results); + + return success(); + } +}; + class WhileOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -106,6 +175,6 @@ public: void mlir::tosa::populateTosaToSCFConversionPatterns( RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); - patterns->add(patterns->getContext()); + patterns->add( + patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp index 759b730..d145350 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp @@ -37,7 +37,7 @@ public: RewritePatternSet 
patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); auto *op = getOperation(); diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index 5993113..4f0e295 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -56,3 +56,33 @@ func.func @if_test(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @scatter_test +// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>) +func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { + + // CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index + // CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index + // CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index + // CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index + // CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index + // CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) { + // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) { + // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32> + // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index + // CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] 
[[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor + // CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor into tensor<3x7x5xi32> + // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32> + // CHECK: } + // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32> + // CHECK: } + %0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>) + + // CHECK: return [[RESULT_0]] : tensor<3x7x5xi32> + return %0 : tensor<3x7x5xi32> +} -- 2.7.4