return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
- auto dynamicDimsOr =
- checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
- if (!dynamicDimsOr.has_value())
- return failure();
- SmallVector<Value> dynamicDims = dynamicDimsOr.value();
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < outputTy.getRank(); i++) {
+ if (outputTy.isDynamicDim(i)) {
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
+ }
// The shift and multiplier values.
SmallVector<int32_t> multiplierValues;
// Construct the indexing maps needed for linalg.generic ops.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
- loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
+ loc, ArrayRef<Value>({dynDims}), outputTy.getShape(),
+ outputTy.getElementType());
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @rescale_i8_dyn
-func.func @rescale_i8_dyn(%arg0 : tensor<?x2xi8>) -> () {
+// CHECK-LABEL: @rescale_i8_dyn_batch
+func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
return
}
+
+// -----
+
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @rescale_dyn
+func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]]
+ // CHECK: %[[C2:.+]] = arith.constant 2
+ // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, %[[DIM1]], %[[DIM2]], 32]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
+ %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1376784203 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [38 : i32]} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+ return
+}
+
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>