tensor = Subtract(tensor, input_zp);
}
- // 3) Multiply the integer multiplier
- if (left_shift != 0) {
- tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
- }
- // Perform the multiplication in higher precision.
- // The scalar is a fixed point value of int32 where the decimal point is
- // between bits 31 and 30. After multiplying with input_tensor, the result is
- // in int64 where the decimal point is sitting between bits 31 and 30 (from
- // the right, rightmost bit is bit 0). The computation is performed in higher
- // precision to avoid overflow in multiplying two int32 values.
- Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
- auto multiplied_t = Multiply(tensor, scalar);
+ // If the input and output scales are same, we can skip the fixed point multiplication.
+ auto scaled_int64_t = tensor;
+ if (param->input_scale != param->output_scale) {
+ // 3) Multiply the integer multiplier
+ if (left_shift != 0) {
+ tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+ }
+ // Perform the multiplication in higher precision.
+ // The scalar is a fixed point value of int32 where the decimal point is
+ // between bits 31 and 30. After multiplying with input_tensor, the result is
+ // in int64 where the decimal point is sitting between bits 31 and 30 (from
+ // the right, rightmost bit is bit 0). The computation is performed in higher
+ // precision to avoid overflow in multiplying two int32 values.
+ Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+ auto multiplied_t = Multiply(tensor, scalar);
- // 4) Find the rounding scalar. This depends on where the final decimal point
- // sits. As we will be right shifting the multiplied_t, we need to first
- // calculate the total_right_shift.
- int total_right_shift = right_shift + 31;
- int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+ // 4) Find the rounding scalar. This depends on where the final decimal point
+ // sits. As we will be right shifting the multiplied_t, we need to first
+ // calculate the total_right_shift.
+ int total_right_shift = right_shift + 31;
+ int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
- tensor = multiplied_t;
- Expr round_scalar;
- if (param->rounding == "UPWARD") {
- round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
- } else if (param->rounding == "TONEAREST") {
- auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
- auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
- auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
- auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+ tensor = multiplied_t;
+ Expr round_scalar;
+ if (param->rounding == "UPWARD") {
+ round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ } else if (param->rounding == "TONEAREST") {
+ auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+ auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+ auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
- auto zero = MakeConstantScalar(hp_dtype, 0);
- auto zero_t = Full(zero, input_shape, hp_dtype);
- round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
- }
- // Add the rounding scalar.
- tensor = Add(tensor, round_scalar);
+ auto zero = MakeConstantScalar(hp_dtype, 0);
+ auto zero_t = Full(zero, input_shape, hp_dtype);
+ round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+ }
+ // Add the rounding scalar.
+ tensor = Add(tensor, round_scalar);
- // 5) Simply right shift the result to get the final output.
- auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+ // 5) Simply right shift the result to get the final output.
+ scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+ }
// 6) Add the output zero point.
- auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
- auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+ auto shifted_int64_t = scaled_int64_t;
+ if (param->output_zero_point != 0) {
+ auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+ shifted_int64_t = Add(output_zp, scaled_int64_t);
+ }
// 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(out_dtype);