auto input = adaptor.getOperands()[0];
auto indices = adaptor.getOperands()[1];
+    auto valuesTy =
+        op.getValues().getType().dyn_cast_or_null<RankedTensorType>();
auto resultTy = op.getType().cast<ShapedType>();
-    auto dynamicDimsOr = checkHasDynamicBatchDims(
-        rewriter, op, {input, indices, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return rewriter.notifyMatchFailure(
-          op, "tosa.gather currently only supports dynamic batch dimensions");
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    if (!valuesTy)
+      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
+    auto dynamicDims = inferDynamicDimsForGather(
+        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
auto resultElementTy = resultTy.getElementType();
auto loc = op.getLoc();
-
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
+
+  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
+                                                            Location loc,
+                                                            Value values,
+                                                            Value indices) {
+    llvm::SmallVector<Value> results;
+
+    // Query the size of `dim` on `source`; only dynamic dimensions yield an
+    // SSA Value here, so static dimensions are skipped.
+    auto addDynamicDimension = [&](Value source, int64_t dim) {
+      auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
+      if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+        results.push_back(dimValue);
+    };
+
+    // The gather result has shape (batch, index count, channels): batch and
+    // channel sizes come from `values`, the row count from `indices`.
+    addDynamicDimension(values, 0);
+    addDynamicDimension(indices, 1);
+    addDynamicDimension(values, 2);
+    return results;
+  }
};
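
For orientation, a sketch of the shape computation this pattern now emits for
a fully dynamic gather (value names here are illustrative; the FileCheck test
below pins down the exact pattern):

  %c0 = arith.constant 0 : index
  %batch = tensor.dim %values, %c0 : tensor<?x?x?xf32>
  %c1 = arith.constant 1 : index
  %count = tensor.dim %indices, %c1 : tensor<?x?xi32>
  %c2 = arith.constant 2 : index
  %chans = tensor.dim %values, %c2 : tensor<?x?x?xf32>
  // tensor.empty takes one index operand per dynamic result dimension.
  %init = tensor.empty(%batch, %count, %chans) : tensor<?x?x?xf32>
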
// Lowers the TableOp to a series of gathers and numeric operations. This
// -----
+// CHECK-LABEL: @gather_float_all_dynamic
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+ // CHECK: %[[C2:.+]] = arith.constant 2
+ // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]])
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>)
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
+ // CHECK: %[[IDX0:.+]] = linalg.index 0
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
+ // CHECK: %[[IDX2:.+]] = linalg.index 2
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32>
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> (tensor<?x?x?xf32>)
+ return
+}
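
In this lowering, #map reads the indices operand at (batch, index) and #map1
writes the result at the identity, so the lowered op as a whole looks roughly
as follows (a sketch with illustrative names, for the fully dynamic case):

  #map = affine_map<(d0, d1, d2) -> (d0, d1)>
  #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  %result = linalg.generic
      {indexing_maps = [#map, #map1],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%indices : tensor<?x?xi32>) outs(%init : tensor<?x?x?xf32>) {
    ^bb0(%idx: i32, %out: f32):
      %b = linalg.index 0 : index
      %i = arith.index_cast %idx : i32 to index
      %c = linalg.index 2 : index
      %v = tensor.extract %values[%b, %i, %c] : tensor<?x?x?xf32>
      linalg.yield %v : f32
  } -> tensor<?x?x?xf32>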
+
+// -----
+
// CHECK-LABEL: @gather_int
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]