ParseOpData returns kTfLiteError when error happens.
authorYu-Cheng Ling <ycling@google.com>
Tue, 10 Apr 2018 21:24:51 +0000 (14:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 10 Apr 2018 21:27:42 +0000 (14:27 -0700)
PiperOrigin-RevId: 192346224

tensorflow/contrib/lite/model.cc

index 13e5532..87af953 100644 (file)
@@ -261,13 +261,11 @@ T* MallocPOD() {
 // Parse the appropriate data out of the op.
 //
 // This handles builtin data explicitly as there are flatbuffer schemas.
-//
-// Returns memory that must be feed.
-//
-// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program
-// crashes if error reporter is called.
-void* ParseOpData(const Operator* op, BuiltinOperator op_type,
-                  ErrorReporter* error_reporter) {
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
+// need to be released by calling `free`.`
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+                         ErrorReporter* error_reporter, void** builtin_data) {
   auto parse_padding = [](Padding padding) {
     switch (padding) {
       case Padding_SAME:
@@ -316,7 +314,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
     }
   };
 
-  void* builtin_data = nullptr;
+  *builtin_data = nullptr;
   switch (op_type) {
     case BuiltinOperator_CALL:
       // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
@@ -333,7 +331,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(conv_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_TANH:
@@ -358,10 +356,11 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
             ConvertTensorType(schema_params->out_data_type(),
                               &params->out_data_type, error_reporter);
         if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
-          break;
+          free(params);
+          return kTfLiteError;
         }
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_LSH_PROJECTION: {
@@ -370,7 +369,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
         params->type = parseLSHProjectionType(lshParams->type());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_AVERAGE_POOL_2D:
@@ -386,7 +385,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(pool_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_DEPTHWISE_CONV_2D: {
@@ -400,7 +399,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(conv_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SVDF: {
@@ -410,7 +409,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(svdf_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
@@ -422,7 +421,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
             parse_activation(sequence_rnn_params->fused_activation_function());
         params->time_major = sequence_rnn_params->time_major();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_RNN: {
@@ -431,7 +430,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(rnn_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_EMBEDDING_LOOKUP:
@@ -444,7 +443,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
               op->builtin_options_as_EmbeddingLookupSparseOptions()) {
         params->combiner = parseCombinerType(embedding_params->combiner());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_FULLY_CONNECTED: {
@@ -455,7 +454,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation = parse_activation(
             fully_connected_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_HASHTABLE_LOOKUP:
@@ -466,7 +465,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
         params->beta = softmax_params->beta();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_CONCATENATION: {
@@ -478,7 +477,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
             parse_activation(concatenation_params->fused_activation_function());
         params->axis = concatenation_params->axis();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_MUL: {
@@ -487,7 +486,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(schema_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_ADD: {
@@ -496,7 +495,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(schema_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_DIV: {
@@ -505,7 +504,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(schema_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SUB: {
@@ -514,7 +513,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(schema_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_L2_NORMALIZATION: {
@@ -523,7 +522,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->activation =
             parse_activation(schema_params->fused_activation_function());
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
@@ -535,7 +534,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->alpha = schema_params->alpha();
         params->beta = schema_params->beta();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
@@ -548,7 +547,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->cell_clip = lstm_params->cell_clip();
         params->proj_clip = lstm_params->proj_clip();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_RESIZE_BILINEAR: {
@@ -557,7 +556,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
               op->builtin_options_as_ResizeBilinearOptions()) {
         params->align_corners = schema_params->align_corners();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_PAD: {
@@ -571,7 +570,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
                                    params->shape, error_reporter);
         params->num_dimensions = new_shape->Length();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SKIP_GRAM: {
@@ -581,7 +580,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->max_skip_size = skip_gram_params->max_skip_size();
         params->include_all_ngrams = skip_gram_params->include_all_ngrams();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SPACE_TO_DEPTH: {
@@ -589,7 +588,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
         params->block_size = schema_params->block_size();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_GATHER: {
@@ -599,7 +598,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->axis = gather_params->axis();
       }
 
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SPACE_TO_BATCH_ND: {
@@ -616,7 +615,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
         params->keep_dims = schema_params->keep_dims();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SPLIT: {
@@ -624,7 +623,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
         params->num_splits = schema_params->num_splits();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_SQUEEZE: {
@@ -635,7 +634,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
                                    params->squeeze_dims, error_reporter);
         params->num_squeeze_dims = squeeze_dims->Length();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_STRIDED_SLICE: {
@@ -647,7 +646,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         params->new_axis_mask = schema_params->new_axis_mask();
         params->shrink_axis_mask = schema_params->shrink_axis_mask();
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_MAXIMUM:
@@ -660,16 +659,16 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
         ConvertTensorType(schema_params->output_type(), &params->output_type,
                           error_reporter);
       }
-      builtin_data = reinterpret_cast<void*>(params);
+      *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
     case BuiltinOperator_DELEGATE: {
       // TODO(ycling): Revisit when supporting saving delegated models.
       error_reporter->Report("DELEGATE op shouldn't exist in model.");
-      break;
+      return kTfLiteError;
     }
   }
-  return builtin_data;
+  return kTfLiteOk;
 }
 
 }  // namespace
@@ -709,10 +708,13 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
           reinterpret_cast<const char*>(op->custom_options()->data()),
           op->custom_options()->size(), nullptr, reg);
     } else {
+      void* builtin_data = nullptr;
+      TF_LITE_ENSURE_STATUS(
+          ParseOpData(op, op_type, error_reporter_, &builtin_data));
       interpreter->AddNodeWithParameters(
           FlatBufferIntArrayToVector(op->inputs()),
-          FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
-          ParseOpData(op, op_type, error_reporter_), reg);
+          FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
+          reg);
     }
   }