[mlir][tosa] Add tosa.gather lowering to linalg.indexed_generic
authornatashaknk <natashaknk@google.com>
Sat, 24 Apr 2021 05:30:08 +0000 (22:30 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Sat, 24 Apr 2021 05:42:56 +0000 (22:42 -0700)
Lowering gather operation to linalg dialect.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D101200

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 21ef0b8..042626e 100644 (file)
@@ -1781,6 +1781,59 @@ public:
   }
 };
 
+class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
+public:
+  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto input = args[0];
+    auto indices = args[1];
+
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto indicesTy = indices.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+
+    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "require input type to have static shape");
+
+    auto resultElementTy = resultTy.getElementType();
+
+    auto loc = op.getLoc();
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(
+            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
+            rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &b, Location loc, ValueRange indices, ValueRange args) {
+          auto indexValue = args[0];
+          auto index0 = indices[0];
+          Value index1 = rewriter.create<IndexCastOp>(
+              loc, rewriter.getIndexType(), indexValue);
+          auto index2 = indices[2];
+          Value extract = rewriter.create<tensor::ExtractOp>(
+              loc, input, ValueRange{index0, index1, index2});
+          rewriter.create<linalg::YieldOp>(loc, extract);
+        });
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
+  }
+};
+
 // Lowerings the TableOp to a series of gathers and numerica operations. This
 // includes interpolation between the high/low values. For the I8 varient, this
 // simplifies to a single gather operation.
@@ -2085,6 +2138,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       ArgMaxConverter,
       ConcatConverter,
       Conv2DConverter,
+      GatherConverter,
       PadConverter,
       ReshapeConverter,
       RescaleConverter,
index ff4dbf4..489bdd3 100644 (file)
@@ -833,6 +833,32 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @gather_float
+func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>)
+  // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
+  // CHECK:   %[[CAST:.+]] = index_cast %[[ARG0]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xf32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<2x3x2xf32>, tensor<2x3xi32>)  -> (tensor<2x3x2xf32>)
+  return
+}
+
+// CHECK-LABEL: @gather_int
+func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>)
+  // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+  // CHECK:   %[[CAST:.+]] = index_cast %[[ARG0]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xi32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<2x3x2xi32>, tensor<2x3xi32>)  -> (tensor<2x3x2xi32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @table8
 func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]