From 860d3811a9b2f3df0ac093d87832056fd7a19b87 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 9 Nov 2021 13:45:17 -0800 Subject: [PATCH] [mlir][tosa] Add lowering for tosa.pad with explicit value New TOSA pad operation can support explicitly specifying the pad value. Added lowering to linalg that uses the explicit value. Differential Revision: https://reviews.llvm.org/D113515 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 36 +++++++++++++--------- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 29 ++++++++++++++--- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index f66e924..5416526 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2381,20 +2381,30 @@ public: "Pad converter requires static shaped input / padding values."); } - Attribute constantAttr; - if (elementTy.isa()) - constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - else if (elementTy.isa() && !padOp.quantization_info()) - constantAttr = rewriter.getIntegerAttr(elementTy, 0); - else if (elementTy.isa() && padOp.quantization_info()) { - auto value = padOp.quantization_info().getValue().input_zp().getValue(); - constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); + // Setup the default constantAttr. + + Value padConstant; + + if (padOp.pad_const()) { + padConstant = rewriter.createOrFold( + loc, padOp.pad_const(), ValueRange({})); + } else { + Attribute constantAttr; + if (elementTy.isa()) + constantAttr = rewriter.getFloatAttr(elementTy, 0.0); + else if (elementTy.isa() && !padOp.quantization_info()) + constantAttr = rewriter.getIntegerAttr(elementTy, 0); + else if (elementTy.isa() && padOp.quantization_info()) { + auto value = padOp.quantization_info().getValue().input_zp().getValue(); + constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); + } + if (constantAttr) + padConstant = rewriter.create(loc, constantAttr); } - if (!constantAttr) { + if (!padConstant) { return rewriter.notifyMatchFailure( - padOp, - "tosa.pad to linalg lowering encountered an unknown element type"); + padOp, "tosa.pad was unable to determine the pad constant value."); } Value lowIndex = @@ -2424,10 +2434,8 @@ public: highValues.push_back(highVal); } - Value constant = rewriter.create(loc, constantAttr); - auto newPadOp = linalg::PadTensorOp::createPadScalarOp( - padOp.getType(), input, constant, lowValues, highValues, + padOp.getType(), input, padConstant, lowValues, highValues, /*nofold=*/false, loc, rewriter); rewriter.replaceOp(padOp, newPadOp.getResult()); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 51d3557..c7ddb6b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1187,11 +1187,11 @@ func @fully_connected_dyn(%arg0: tensor, %arg1: tensor<6x3xf32>, %arg2: func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK: [[INDEX4:%.+]] = arith.constant 4 : index - // CHECK: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK: linalg.pad_tensor %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { // CHECK: ^bb0(%arg1: index, %arg2: index): // no predecessors // CHECK: linalg.yield [[CST]] @@ -1220,6 +1220,25 @@ func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { // ----- +func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { + %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // TODO: Output contains multiple "arith.constant 1 : index". + // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32 + // CHECK: linalg.pad_tensor %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: ^bb0(%arg1: index, %arg2: index): // no predecessors + // CHECK: linalg.yield [[CST]] + // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> + %1 = arith.constant dense<42.0> : tensor + %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor) -> (tensor<4x9xf32>) + return %2 : tensor<4x9xf32> +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> -- 2.7.4