From 4c48f7e29b7014af5ba8292a508b8386e6b00f03 Mon Sep 17 00:00:00 2001
From: natashaknk
Date: Wed, 6 Oct 2021 10:29:37 -0700
Subject: [PATCH] [mlir][tosa] Create basic dynamic shape support for several
 ops.

Adds dynamic shape support to the tosa.transpose, tosa.matmul and
tosa.fully_connected lowerings: each converter now materializes the
dynamic extents of its result with tensor.dim and passes them to
linalg.init_tensor.
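For example, a tosa.matmul with a dynamic batch dimension now lowers by
querying the dynamic extent with tensor.dim and forwarding it into
linalg.init_tensor. A sketch distilled from the new matmul_dyn_batch test
below (the %zero fill constant is abbreviated):

  %0 = "tosa.matmul"(%arg0, %arg1)
      : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>

becomes, roughly:

  %c0 = constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x5x3xf32>
  %init = linalg.init_tensor [%dim, 5, 6] : tensor<?x5x6xf32>
  %filled = linalg.fill(%zero, %init) : f32, tensor<?x5x6xf32> -> tensor<?x5x6xf32>
  %0 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>)
                           outs(%filled : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>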
Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D111167
---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp  |  81 ++++++++++++---
 .../Conversion/TosaToLinalg/tosa-to-linalg.mlir    | 111 ++++++++++++++++++++-
 2 files changed, 175 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 77e4c26..f24a849 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -91,6 +91,14 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
       .result();
 }
 
+static SmallVector<Value> filterDynamicDims(SmallVector<Value> dynDims) {
+  SmallVector<Value> filteredDims;
+  for (auto dim : dynDims)
+    if (dim)
+      filteredDims.push_back(dim);
+  return filteredDims;
+}
+
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -690,10 +698,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
     }
   }
 
-  SmallVector<Value> filteredDims;
-  for (auto dim : dynDims)
-    if (dim)
-      filteredDims.push_back(dim);
+  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
 
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
@@ -1355,10 +1360,31 @@ public:
     auto outputTy = op.getType().cast<ShapedType>();
     auto outputElementTy = outputTy.getElementType();
+
+    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
+    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();
+
+    SmallVector<Value> dynDims;
+    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+
+    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
+      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
+    }
+
+    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
+      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
+    }
+
+    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
+      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
+    }
+
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
     Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, outputTy.getShape(), outputTy.getElementType());
+        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
     if (!op.quantization_info()) {
@@ -1393,14 +1419,29 @@ public:
     Location loc = op.getLoc();
     auto outputTy = op.getType().cast<ShapedType>();
     auto input = op.input();
-    auto weight = op.weight();
+    auto inputTy = input.getType().cast<ShapedType>();
+
     auto bias = op.bias();
+
+    auto weight = op.weight();
     auto weightTy = weight.getType().cast<ShapedType>();
     auto weightShape = weightTy.getShape();
 
     auto outputETy = outputTy.getElementType();
 
+    SmallVector<Value> dynDims;
+    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+
+    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
+      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
+    }
+
+    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
+      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
+    }
+
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
 
@@ -1413,7 +1454,7 @@ public:
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, outputTy.getShape(), outputTy.getElementType());
+        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
 
     // When quantized, the input element type is not the same as the output
     Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
@@ -1435,7 +1476,8 @@ public:
 
     auto biasInitTensor =
         rewriter
-            .create<linalg::InitTensorOp>(loc, outputTy.getShape(), outputETy)
+            .create<linalg::InitTensorOp>(loc, filteredDims,
+                                          outputTy.getShape(), outputETy)
            ->getResults();
 
     if (!op.quantization_info()) {
@@ -1614,20 +1656,29 @@ public:
       return failure();
     }
 
+    auto loc = op.getLoc();
+    auto input = op->getOperand(0);
     auto resultTy = op.getType().cast<ShapedType>();
-    if (!resultTy.hasStaticShape())
-      return failure();
+
+    SmallVector<Value> dynDims;
+    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
 
     SmallVector<AffineExpr, 4> inputExprs;
     inputExprs.resize(resultTy.getRank());
+    auto operandTy = input.getType().cast<ShapedType>();
     for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
-      inputExprs[permutation.value().getZExtValue()] =
-          rewriter.getAffineDimExpr(permutation.index());
+      auto index = permutation.index();
+      auto value = permutation.value().getZExtValue();
+      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
+        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
+      }
+      inputExprs[value] = rewriter.getAffineDimExpr(index);
     }
 
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
-        resultTy.getElementType());
+        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
 
     SmallVector<AffineMap, 2> affineMaps = {
         AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
@@ -1638,7 +1689,7 @@ public:
         op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
         });
     return success();
   }
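A note on the transpose case, since the permutation makes it the least
obvious of the three: the converter walks the permutation attribute and,
for every dynamic input dimension, emits a tensor.dim at the input index
but stores the result at the permuted output index, so the extents reach
linalg.init_tensor already reordered. Illustrative IR distilled from the
test_transpose_dyn_multiple expectations below (value names shortened):

  %c0 = constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>            // input dim 0
  %c1 = constant 1 : index
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>            // input dim 1
  %init = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>  // swapped by perms = [1, 0]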
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index ba9b0d2..1c81a2a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -592,6 +592,48 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @test_transpose_dyn
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
+func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
+  %0 = constant dense<[1, 3, 0, 2]> : tensor<4xi32>
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 4, 1, 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs([[OUT:%.+]] : tensor<?x4x1x3xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK: linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> (tensor<?x4x1x3xi32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_transpose_dyn_multiple
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
+func @test_transpose_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
+  %0 = constant dense<[1, 0]> : tensor<2xi32>
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM1]], %[[DIM0]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs([[OUT:%.+]] : tensor<?x?xf32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
+  // CHECK: linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?xf32>, tensor<2xi32>) -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -987,7 +1029,7 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
 
 // CHECK-LABEL: @matmul
-func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1x6xf32>) -> (tensor<1x5x6xf32>) {
+func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
   // CHECK: [[C0:%.+]] = constant 0
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
   // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
@@ -1013,6 +1055,46 @@ func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (ten
 
 // -----
 
+// CHECK-LABEL: @matmul_dyn_batch
+func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>) -> (tensor<?x5x6xf32>) {
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C0_0:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 5, 6]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0_0]], %[[INIT]]) : f32, tensor<?x5x6xf32> -> tensor<?x5x6xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> (tensor<?x5x6xf32>)
+  return %0 : tensor<?x5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @matmul_dyn_independent_dim
+func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>) {
+  // CHECK: %[[C2:.+]] = constant 2
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg1, %[[C2]]
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, %[[DIM]]]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x?xf32> -> tensor<1x5x?xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>)
+  return %0 : tensor<1x5x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @matmul_dyn_independent_dim
+func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) {
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, 6]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>)
+  return %0 : tensor<1x5x6xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -1055,7 +1137,7 @@ func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %a
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: linalg.yield [[IN]] : i8
   // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[ONE:%.+]] = constant 1
+  // CHECK: [[ONE:%.+]] = constant 1
   // CHECK: [[TWO:%.+]] = constant 2
   // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
   // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
@@ -1068,6 +1150,31 @@ func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %a
 
 // -----
 
+// CHECK-LABEL: @fully_connected_dyn
+func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INITT:.+]] = linalg.init_tensor [%[[DIM]], 6]
+  // CHECK: %[[ZERO:.+]] = constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[ZERO]], %[[INITT]])
+  // CHECK: %[[PERM:.+]] = constant dense<[1, 0]>
+  // CHECK: %[[INITT:.+]] = linalg.init_tensor [3, 6]
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs(%[[INITT]] : tensor<3x6xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[UNUSED:.+]]: f32):
+  // CHECK: linalg.yield %[[IN]] : f32
+  // CHECK: %[[INITB:.+]] = linalg.init_tensor [%[[DIM]], 6]
+  // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
+  // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK: %[[ADD:.+]] = addf %arg3, %arg4 : f32
+  // CHECK: linalg.yield %[[ADD]] : f32
+
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> (tensor<?x6xf32>)
+  return %0 : tensor<?x6xf32>
+}
+
+// -----
+
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "constant 1 : index".
-- 
2.7.4
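One detail worth spelling out from the fully_connected_dyn expectations
above: the bias never needs an explicit broadcast along the dynamic batch
dimension; the indexing maps alone replicate it. Illustrative IR
reconstructed from the FileCheck lines (map aliases and value names here
are mine, not from the test):

  // #map1 = affine_map<(d0, d1) -> (d0, d1)>   result/matmul index
  // #map2 = affine_map<(d0, d1) -> (d1)>       bias index, batch dim d0 dropped
  %added = linalg.generic {indexing_maps = [#map2, #map1, #map1],
                           iterator_types = ["parallel", "parallel"]}
      ins(%arg2, %matmul : tensor<6xf32>, tensor<?x6xf32>)
      outs(%initb : tensor<?x6xf32>) {
  ^bb0(%bias: f32, %acc: f32, %out: f32):
    %sum = addf %bias, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<?x6xf32>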