[mlir][tosa] Lowering of tosa.gather operations with dynamic dimensions
author: Spenser Bauman <sbauman@mathworks.com>
Mon, 10 Apr 2023 15:55:54 +0000 (15:55 +0000)
committer: Robert Suderman <suderman@google.com>
Mon, 10 Apr 2023 15:56:57 +0000 (15:56 +0000)
The existing TOSA->Linalg lowering of tosa.gather only supports gathers
with either a static shape or a single dynamic batch dimension.
This change extends support to an arbitrary number of dynamic dimensions on
both the values and indices of the gather operation.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D147810

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index be24f5e..b2e59b2 100644 (file)
@@ -1829,19 +1829,19 @@ public:
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
+    auto valuesTy =
+        op.getValues().getType().dyn_cast_or_null<RankedTensorType>();
     auto resultTy = op.getType().cast<ShapedType>();
 
-    auto dynamicDimsOr = checkHasDynamicBatchDims(
-        rewriter, op, {input, indices, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return rewriter.notifyMatchFailure(
-          op, "tosa.gather currently only supports dynamic batch dimensions");
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    if (!valuesTy)
+      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
+    auto dynamicDims = inferDynamicDimsForGather(
+        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
 
     auto resultElementTy = resultTy.getElementType();
 
     auto loc = op.getLoc();
-
     auto emptyTensor =
         rewriter
             .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
@@ -1872,6 +1872,24 @@ public:
     rewriter.replaceOp(op, genericOp.getResult(0));
     return success();
   }
+
+  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
+                                                            Location loc,
+                                                            Value values,
+                                                            Value indices) {
+    llvm::SmallVector<Value> results;
+
+    auto addDynamicDimension = [&](Value source, int64_t dim) {
+      auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
+      if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+        results.push_back(dimValue);
+    };
+
+    addDynamicDimension(values, 0);
+    addDynamicDimension(indices, 1);
+    addDynamicDimension(values, 2);
+    return results;
+  }
 };
 
// Lowers the TableOp to a series of gathers and numeric operations. This
index 476131b..e9e9037 100644 (file)
@@ -1267,6 +1267,30 @@ func.func @gather_float_dyn(%arg0: tensor<?x3x2xf32>, %arg1: tensor<?x3xi32>) ->
 
 // -----
 
+// CHECK-LABEL: @gather_float_all_dynamic
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2
+  // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]])
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>)
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
+  // CHECK:   %[[IDX0:.+]] = linalg.index 0
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
+  // CHECK:   %[[IDX2:.+]] = linalg.index 2
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<?x?x?xf32>, tensor<?x?xi32>)  -> (tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @gather_int
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 // CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]