}
// tosa::NegateOp
- if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+
+ if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+ !cast<tosa::NegateOp>(op).quantization_info()) {
auto constant =
- rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
- return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
+ rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
}
- if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
- return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+ if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+ cast<tosa::NegateOp>(op).quantization_info()) {
+ auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
+ int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+ int64_t inZp =
+ quantizationInfo.getValue().input_zp().getValue().getSExtValue();
+ int64_t outZp =
+ quantizationInfo.getValue().output_zp().getValue().getSExtValue();
+
+ // Compute the maximum value that can occur in the intermediate buffer.
+ int64_t zpAdd = inZp + outZp;
+ int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Pick the narrowest of 16/32/48/64 bits that can represent that maximum
+ // value. We assume 48-bit numbers may be supported further in the pipeline.
+ int intermediateBitWidth = 64;
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
+
+ Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ Value zpAddValue = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+ // The negation can be applied by doing:
+ // outputValue = inZp + outZp - inputValue
+ auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
+ auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);
+
+ // Clamp to the signed range of the original bit width before truncating.
+ auto min = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ intermediateType,
+ APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
+ auto max = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ intermediateType,
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
+ auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
+ CmpIPredicate::slt, rewriter);
+
+ // Truncate to the final value.
+ return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
+ }
// tosa::BitwiseAndOp
if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
%3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: muli
+ // CHECK: [[ZERO:%.+]] = constant 0
+ // CHECK: subi [[ZERO]], %arg1
%4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// -----
+// CHECK-LABEL: @test_negate_quantized
+func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // Quantized i8 negate: expects sexti, subi from the zero-point constant, clamp to [-128, 127], trunci.
+ // CHECK: linalg.generic
+ // CHECK: [[ZERO:%.+]] = constant 0
+ // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16
+ // CHECK: [[SUB:%.+]] = subi [[ZERO]], [[EXT]]
+ // CHECK: [[MIN:%.+]] = constant -128
+ // CHECK: [[MAX:%.+]] = constant 127
+ // CHECK: [[PRED1:%.+]] = cmpi slt, [[SUB]], [[MIN]]
+ // CHECK: [[LBOUND:%.+]] = select [[PRED1]], [[MIN]], [[SUB]]
+ // CHECK: [[PRED2:%.+]] = cmpi slt, [[MAX]], [[SUB]]
+ // CHECK: [[UBOUND:%.+]] = select [[PRED2]], [[MAX]], [[LBOUND]]
+ // CHECK: [[TRUNC:%.+]] = trunci [[UBOUND]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %0 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 0 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+ // Largest zero point that still fits an i16 intermediate: 127 + 32639 + 1 == 32767.
+ // CHECK: linalg.generic
+ // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16
+ %1 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32639 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+ // One past that boundary: 127 + 32640 + 1 == 32768 exceeds the i16 maximum, so an i32 intermediate is expected.
+ // CHECK: linalg.generic
+ // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i32
+ %2 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32640 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_reshape_downrank
func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {