[mlir][tosa] Support non-batch dynamic dims for tosa.rescale to linalg
authornatashaknk <natashaknk@google.com>
Wed, 7 Sep 2022 00:13:12 +0000 (17:13 -0700)
committerRob Suderman <suderman@google.com>
Wed, 7 Sep 2022 00:30:02 +0000 (17:30 -0700)
Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133383

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 3a6d72b..f930a2b 100644 (file)
@@ -1126,11 +1126,12 @@ public:
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = dynamicDimsOr.value();
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < outputTy.getRank(); i++) {
+      if (outputTy.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
 
     // The shift and multiplier values.
     SmallVector<int32_t> multiplierValues;
@@ -1206,7 +1207,8 @@ public:
 
     // Construct the indexing maps needed for linalg.generic ops.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
+        loc, ArrayRef<Value>({dynDims}), outputTy.getShape(),
+        outputTy.getElementType());
 
     auto linalgOp = rewriter.create<linalg::GenericOp>(
         loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
index 6a49ba5..2e4dfbf 100644 (file)
@@ -1004,8 +1004,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
-// CHECK-LABEL: @rescale_i8_dyn
-func.func @rescale_i8_dyn(%arg0 : tensor<?x2xi8>) -> () {
+// CHECK-LABEL: @rescale_i8_dyn_batch
+func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
@@ -1020,6 +1020,23 @@ func.func @rescale_i8_dyn(%arg0 : tensor<?x2xi8>) -> () {
 
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @rescale_dyn
+func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2
+  // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, %[[DIM1]], %[[DIM2]], 32]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
+  %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1376784203 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [38 : i32]} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>