}
};
+class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
+ static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
+ int64_t dim) {
+ return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
+ }
+
+ static Value createIndexConst(OpBuilder &builder, Location loc,
+ int64_t value) {
+ return builder.create<arith::ConstantIndexOp>(loc, value);
+ }
+
+public:
+ using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
+ PatternRewriter &rewriter) const final {
+ auto valuesIn = scatter.getValuesIn();
+ auto indices = scatter.getIndices();
+ auto input = scatter.getInput();
+ auto loc = scatter.getLoc();
+
+ // N, W, C are chosen to match the TOSA spec
+ auto dimN = createTensorDim(rewriter, loc, input, 0);
+ auto dimW = createTensorDim(rewriter, loc, input, 1);
+ auto dimC = createTensorDim(rewriter, loc, input, 2);
+
+ auto zero = createIndexConst(rewriter, loc, 0);
+ auto one = createIndexConst(rewriter, loc, 1);
+
+ // Loop bounds
+ auto lbs = llvm::SmallVector<Value>(2, zero);
+ auto steps = llvm::SmallVector<Value>(2, one);
+ auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
+
+ auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+ ValueRange args) -> scf::ValueVector {
+ auto n = ivs[0];
+
+ // Read the index and cast it to index type
+ auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
+ auto castIndex = builder.create<arith::IndexCastOp>(
+ loc, builder.getIndexType(), index);
+
+ // Offset, sizes, and strides for the input tensor
+ auto inputOffset = llvm::to_vector(ivs);
+ inputOffset.push_back(zero);
+
+ llvm::SmallVector<Value> sizes = {one, one, dimC};
+ llvm::SmallVector<Value> strides = {one, one, one};
+
+ auto slice = builder.create<tensor::ExtractSliceOp>(
+ loc, input, inputOffset, sizes, strides);
+
+ // Insert the slice into the output accumulator tensor.
+ llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
+ auto updated = builder.create<tensor::InsertSliceOp>(
+ loc, slice, args[0], outputOffset, sizes, strides);
+
+ return {updated};
+ };
+
+ auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
+ ValueRange{valuesIn}, buildBody);
+ rewriter.replaceOp(scatter, loops.results);
+
+ return success();
+ }
+};
+
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
public:
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
void mlir::tosa::populateTosaToSCFConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<IfOpConverter>(patterns->getContext());
- patterns->add<WhileOpConverter>(patterns->getContext());
+ patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
+ patterns->getContext());
}
return %0 : tensor<f32>
}
+
+// -----
+
+// CHECK-LABEL: func @scatter_test
+// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>)
+func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
+
+ // CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index
+ // CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index
+ // CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index
+ // CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) {
+ // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
+ // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
+ // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
+ // CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
+ // CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
+ // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
+ // CHECK: }
+ // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>
+ // CHECK: }
+ %0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>)
+
+ // CHECK: return [[RESULT_0]] : tensor<3x7x5xi32>
+ return %0 : tensor<3x7x5xi32>
+}