From b45476c94ce8ea94e2ad4d93ceda00eb4078e682 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 24 Sep 2021 16:57:46 -0400 Subject: [PATCH] [mlir][tosa] Do not fold transpose with quantized types For such cases, the type of the constant DenseElementsAttr is different from the transpose op return type. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D110446 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 9 ++++++--- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a51d02e..780cd03 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -165,6 +165,12 @@ struct ConstantTransposeOptimization LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override { + auto outputType = op.getType().cast<ShapedType>(); + ArrayRef<int64_t> outputShape = outputType.getShape(); + // TOSA supports quantized types. 
+ if (!outputType.getElementType().isIntOrIndexOrFloat()) + return failure(); + DenseElementsAttr inputValues; if (!matchPattern(op.input1(), m_Constant(&inputValues))) return failure(); @@ -184,9 +190,6 @@ struct ConstantTransposeOptimization ArrayRef<int64_t> inputShape = inputType.getShape(); int64_t numElements = inputType.getNumElements(); - auto outputType = op.getType().cast<ShapedType>(); - ArrayRef<int64_t> outputShape = outputType.getShape(); - SmallVector<Attribute> outputValues; outputValues.resize(numElements); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index a3ede85..8e8206f 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -314,3 +314,14 @@ func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> } + +// ----- + +// CHECK-LABEL: @transpose_nofold_quantized_types +func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> { + %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> + %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8> + // CHECK: tosa.transpose + %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> + return %0: 
tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> +} -- 2.7.4