type == OperatorType::kTanh || type == OperatorType::kMul ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
- type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell;
+ type == OperatorType::kDepthToSpace ||
+ type == OperatorType::kLstmCell || type == OperatorType::kGather ||
+ type == OperatorType::kTranspose;
}
template <ArrayDataType A>
//
// Let us just guard this assumption by the following assertion:
for (const auto& input : op.inputs) {
- if (IsInputArray(*model, input)) {
- const auto& input_array = model->GetArray(input);
- CHECK(input_array.quantization_params);
+ const auto& input_array = model->GetArray(input);
+ if (IsInputArray(*model, input) &&
+ input_array.data_type == ArrayDataType::kFloat) {
+ CHECK(input_array.quantization_params)
+ << "Input array " << input << " is missing quantization_params";
}
}
if (!SupportsQuantization(op)) {
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <limits>
#include <memory>
#include <string>
#include <vector>
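For context on the hunks below: the transformation treats a fused activation clamp as trivial (removable) when every value the quantized uint8 output array can represent already lies inside the clamp bounds, i.e. when [(0 - zero_point) * scale, (255 - zero_point) * scale] is contained in [clamp_min, clamp_max]. The following standalone sketch, which is not part of the patch and uses hypothetical quantization parameters and illustrative identifiers, walks through that arithmetic for a Relu6 clamp.

    #include <cstdio>

    int main() {
      // Hypothetical uint8 quantization parameters (illustrative only).
      const double scale = 6.0 / 255.0;
      const double zero_point = 0.0;
      // Representable output range for a uint8 array with these parameters.
      const double lowest = (0.0 - zero_point) * scale;     // 0.0
      const double highest = (255.0 - zero_point) * scale;  // 6.0
      // Clamp bounds for a fused Relu6 activation, as in the switch below.
      const double clamp_min = 0.0;
      const double clamp_max = 6.0;
      // The clamp is trivial when the representable range already lies
      // inside the clamp bounds; otherwise the activation must be kept.
      const bool trivial = lowest >= clamp_min && highest <= clamp_max;
      std::printf("representable range [%g, %g], clamp trivial: %s\n", lowest,
                  highest, trivial ? "yes" : "no");
      return 0;
    }

With these assumed parameters the representable range is exactly [0, 6], so a fused Relu6 would be reported as trivial; a zero_point or scale that lets the array represent values outside the clamp bounds triggers the non-trivial messages added in the patch.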
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->fused_activation_function != FusedActivationFunctionType::kRelu &&
+ op->fused_activation_function != FusedActivationFunctionType::kRelu1 &&
op->fused_activation_function != FusedActivationFunctionType::kRelu6) {
return false;
}
}
const auto& quantization_params = output_array.GetQuantizationParams();
+ double clamp_min;
+ double clamp_max;
+ switch (op->fused_activation_function) {
+ case FusedActivationFunctionType::kRelu:
+ clamp_min = 0.0;
+ clamp_max = std::numeric_limits<double>::infinity();
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ clamp_min = -1.0;
+ clamp_max = 1.0;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ clamp_min = 0.0;
+ clamp_max = 6.0;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported fused activation type: "
+ << static_cast<int>(op->fused_activation_function);
+ return false;
+ }
+
bool has_nontrivial_min_bound = false;
bool has_nontrivial_max_bound = false;
- if (op->fused_activation_function == FusedActivationFunctionType::kRelu ||
- op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
- double lowest_representable_output =
- (0. - quantization_params.zero_point) * quantization_params.scale;
- if (lowest_representable_output < 0.) {
- has_nontrivial_min_bound = true;
- AddMessageF(
- "Quantized activation function is not trivial: "
- "the lowest representable output value %g"
- " less than the clamp min bound.",
- lowest_representable_output);
- }
+ double lowest_representable_output =
+ (0. - quantization_params.zero_point) * quantization_params.scale;
+ if (lowest_representable_output < clamp_min) {
+ has_nontrivial_min_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the lowest representable output value %g"
+ " less than the clamp min bound %g.",
+ lowest_representable_output, clamp_min);
}
- if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
- double highest_representable_output =
- (255. - quantization_params.zero_point) * quantization_params.scale;
- if (highest_representable_output > 6.) {
- has_nontrivial_max_bound = true;
- AddMessageF(
- "Quantized activation function is not trivial: "
- "the highest representable output value %g"
- " is greater than the clamp max bound.",
- highest_representable_output);
- }
+ double highest_representable_output =
+ (255. - quantization_params.zero_point) * quantization_params.scale;
+ if (highest_representable_output > clamp_max) {
+ has_nontrivial_max_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the highest representable output value %g"
+ " is greater than the clamp max bound %g.",
+ highest_representable_output, clamp_max);
}
if (has_nontrivial_min_bound || has_nontrivial_max_bound) {