Supporting quantization of Gather ops and removal of trivial Relu1s when quantized.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 17:12:22 +0000 (10:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 17:16:20 +0000 (10:16 -0700)
PiperOrigin-RevId: 188738133

tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc

index 48a67ca..5cc82da 100644 (file)
@@ -330,6 +330,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
     case OperatorType::kSqueeze:
     case OperatorType::kTensorFlowReshape:
     case OperatorType::kPad:
+    case OperatorType::kGather:
+    case OperatorType::kTranspose:
       changed = HardcodeMinMaxFromFirstInput(model, op);
       break;
 
index 05686ce..ad3f052 100644 (file)
@@ -50,7 +50,9 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kTanh || type == OperatorType::kMul ||
          type == OperatorType::kSpaceToDepth ||
          type == OperatorType::kStridedSlice ||
-         type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell;
+         type == OperatorType::kDepthToSpace ||
+         type == OperatorType::kLstmCell || type == OperatorType::kGather ||
+         type == OperatorType::kTranspose;
 }
 
 template <ArrayDataType A>
@@ -511,9 +513,11 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
   //
   // Let us just guard this assumption by the following assertion:
   for (const auto& input : op.inputs) {
-    if (IsInputArray(*model, input)) {
-      const auto& input_array = model->GetArray(input);
-      CHECK(input_array.quantization_params);
+    const auto& input_array = model->GetArray(input);
+    if (IsInputArray(*model, input) &&
+        input_array.data_type == ArrayDataType::kFloat) {
+      CHECK(input_array.quantization_params)
+          << "Input array " << input << " is missing quantization_params";
     }
   }
   if (!SupportsQuantization(op)) {
index 28f76c9..9b65fea 100644 (file)
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <limits>
 #include <memory>
 #include <string>
 #include <vector>
@@ -30,6 +31,7 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
   const auto it = model->operators.begin() + op_index;
   auto* op = it->get();
   if (op->fused_activation_function != FusedActivationFunctionType::kRelu &&
+      op->fused_activation_function != FusedActivationFunctionType::kRelu1 &&
       op->fused_activation_function != FusedActivationFunctionType::kRelu6) {
     return false;
   }
@@ -42,33 +44,49 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
   }
   const auto& quantization_params = output_array.GetQuantizationParams();
 
+  double clamp_min;
+  double clamp_max;
+  switch (op->fused_activation_function) {
+    case FusedActivationFunctionType::kRelu:
+      clamp_min = 0.0;
+      clamp_max = std::numeric_limits<double>::infinity();
+      break;
+    case FusedActivationFunctionType::kRelu1:
+      clamp_min = -1.0;
+      clamp_max = 1.0;
+      break;
+    case FusedActivationFunctionType::kRelu6:
+      clamp_min = 0.0;
+      clamp_max = 6.0;
+      break;
+    default:
+      LOG(FATAL) << "Unsupported fused activation type: "
+                 << static_cast<int>(op->fused_activation_function);
+      return false;
+  }
+
   bool has_nontrivial_min_bound = false;
   bool has_nontrivial_max_bound = false;
 
-  if (op->fused_activation_function == FusedActivationFunctionType::kRelu ||
-      op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
-    double lowest_representable_output =
-        (0. - quantization_params.zero_point) * quantization_params.scale;
-    if (lowest_representable_output < 0.) {
-      has_nontrivial_min_bound = true;
-      AddMessageF(
-          "Quantized activation function is not trivial: "
-          "the lowest representable output value %g"
-          " less than the clamp min bound.",
-          lowest_representable_output);
-    }
+  double lowest_representable_output =
+      (0. - quantization_params.zero_point) * quantization_params.scale;
+  if (lowest_representable_output < clamp_min) {
+    has_nontrivial_min_bound = true;
+    AddMessageF(
+        "Quantized activation function is not trivial: "
+        "the lowest representable output value %g"
+        " less than the clamp min bound %g.",
+        lowest_representable_output, clamp_min);
   }
-  if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
-    double highest_representable_output =
-        (255. - quantization_params.zero_point) * quantization_params.scale;
-    if (highest_representable_output > 6.) {
-      has_nontrivial_max_bound = true;
-      AddMessageF(
-          "Quantized activation function is not trivial: "
-          "the highest representable output value %g"
-          " is greater than the clamp max bound.",
-          highest_representable_output);
-    }
+  double highest_representable_output =
+      (255. - quantization_params.zero_point) * quantization_params.scale;
+  if (highest_representable_output > clamp_max) {
+    has_nontrivial_max_bound = true;
+    AddMessageF(
+        "Quantized activation function is not trivial: "
+        "the highest representable output value %g"
+        " is greater than the clamp max bound %g.",
+        highest_representable_output, clamp_max);
   }
 
   if (has_nontrivial_min_bound || has_nontrivial_max_bound) {