-Miscellaneous code clean-up
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Apr 2018 18:51:03 +0000 (11:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 18:53:41 +0000 (11:53 -0700)
PiperOrigin-RevId: 194821201

tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
tensorflow/contrib/lite/toco/import_tensorflow.cc

index de6d888..bddb563 100644 (file)
@@ -79,8 +79,9 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
   const auto* max_op =
       op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
 
-  CHECK_EQ(min_op->inputs.size(), 2);
-  CHECK_EQ(max_op->inputs.size(), 2);
+  if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
+    return false;
+  }
   if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
     return false;
   }
index 8e6aaf5..1956ab2 100644 (file)
@@ -88,13 +88,11 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
   // At that point we know that none of the outputs is used, so we will
   // definitely remove the node and all its outputs.
 
-  // Remove any input array that is not used by anything else,
-  // and that is not the output of some other operator.
+  // Remove any input array that not the output of another op, and only used by
+  // this op.
   for (const auto& input : op->inputs) {
-    if (IsDiscardableArray(*model, input) &&
-        CountOpsWithInput(*model, input) == 1 &&
-        !GetOpWithOutput(*model, input)) {
-      model->EraseArray(input);
+    if (!GetOpWithOutput(*model, input)) {
+      DeleteArrayIfUsedOnce(input, model);
     }
   }
 
@@ -102,22 +100,9 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
   for (const auto& output : op->outputs) {
     // If the output array is the model's input array, don't remove that.
     // That's the case when cropping a model at a given --input_array.
-    if (!IsDiscardableArray(*model, output)) {
-      continue;
-    }
-    // Likewise, if the output array is a RNN state array, don't remove that.
-    bool found_output_as_rnn_state_array = false;
-    for (const auto& rnn_state : model->flags.rnn_states()) {
-      if (output == rnn_state.state_array()) {
-        found_output_as_rnn_state_array = true;
-        break;
-      }
-    }
-    if (found_output_as_rnn_state_array) {
-      continue;
+    if (IsDiscardableArray(*model, output)) {
+      model->EraseArray(output);
     }
-    // Generic case: do delete this output array.
-    model->EraseArray(output);
   }
   model->operators.erase(it);
   return true;
index ea0d6dc..69db194 100644 (file)
@@ -77,6 +77,13 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
     }
   }
 
+  int axis = op->axis;
+  if (axis < 0) {
+    // Handle negative axis
+    axis += model->GetArray(op->inputs[0]).shape().dims().size();
+  }
+  CHECK_EQ(axis, 0) << "Stacking only supported along 0th axis";
+
   CHECK(!output_array.buffer);
   switch (output_array.data_type) {
     case ArrayDataType::kFloat:
@@ -99,10 +106,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
 
   // Erase input arrays if no longer used
   for (const auto& input : op->inputs) {
-    if (IsDiscardableArray(*model, input) &&
-        CountOpsWithInput(*model, input) == 1) {
-      model->EraseArray(input);
-    }
+    toco::DeleteArrayIfUsedOnce(input, model);
   }
 
   // Erase the operator
index 2ed05cb..61e4c9d 100644 (file)
@@ -451,8 +451,18 @@ void ConvertConvOperator(const NodeDef& node,
   if (HasAttr(node, "dilations")) {
     const auto& dilations = GetListAttr(node, "dilations");
     CHECK_EQ(dilations.i_size(), 4);
-    CHECK_EQ(dilations.i(0), 1);
-    CHECK_EQ(dilations.i(3), 1);
+    CHECK_EQ(dilations.i(0), 1)
+        << "Can only import Conv ops with dilation along the height (1st) or "
+           "width (2nd) axis. TensorFlow op \""
+        << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+        << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+        << "].";
+    CHECK_EQ(dilations.i(3), 1)
+        << "Can only import Conv ops with dilation along the height (1st) or "
+           "width (2nd) axis. TensorFlow op \""
+        << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+        << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+        << "].";
     conv->dilation_height_factor = dilations.i(1);
     conv->dilation_width_factor = dilations.i(2);
   } else {