[pytorch] Do first run with try/catch
authorParichay Kapoor <pk.kapoor@samsung.com>
Tue, 22 Oct 2019 04:27:04 +0000 (13:27 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 24 Oct 2019 04:28:53 +0000 (13:28 +0900)
First run after configuring the tensor filter with pytorch framework
is now run with a try/catch.

Check #1809 for more details.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
ext/nnstreamer/tensor_filter/tensor_filter_pytorch_core.cc
ext/nnstreamer/tensor_filter/tensor_filter_pytorch_core.h

index 8e2b807..8a56f48 100644 (file)
@@ -44,6 +44,7 @@ TorchCore::TorchCore (const char *_model_path)
 {
   model_path = _model_path;
   configured = false;
+  first_run = true;
 
   gst_tensors_info_init (&inputTensorMeta);
   gst_tensors_info_init (&outputTensorMeta);
@@ -77,6 +78,8 @@ TorchCore::init (const GstTensorFilterProperties * prop,
     g_critical ("Failed to load model\n");
     return -1;
   }
+
+  first_run = true;
   return 0;
 }
 
@@ -317,9 +320,10 @@ TorchCore::processIValue (torch::jit::IValue value, GstTensorMemory * output)
  * @param[in] input : The array of input tensors
  * @param[out]  output : The array of output tensors
  * @return 0 if OK. non-zero if error.
- *         -1 if the model does not work properly.
+ *         -1 if the input properties are incompatible.
  *         -2 if the output properties are different with model.
  *         -3 if the output is neither a list nor a tensor.
+ *         -4 if running the model failed.
  */
 int
 TorchCore::invoke (const GstTensorMemory * input, GstTensorMemory * output)
@@ -354,7 +358,27 @@ TorchCore::invoke (const GstTensorMemory * input, GstTensorMemory * output)
     input_feeds.emplace_back (tensor);
   }
 
-  output_value = model->forward (input_feeds);
+  /**
+   * As the input information has not been verified, the first run for the model
+   * is encapsulated in a try-catch block
+   */
+  if (first_run) {
+    try {
+      output_value = model->forward (input_feeds);
+      first_run = false;
+    } catch(const std::runtime_error& re) {
+      g_critical ("Runtime error while running the model: %s", re.what());
+      return -4;
+    } catch(const std::exception& ex)  {
+      g_critical ("Exception while running the model : %s", ex.what());
+      return -4;
+    } catch (...) {
+      g_critical ("Unknown exception while running the model");
+      return -4;
+    }
+  } else {
+    output_value = model->forward (input_feeds);
+  }
 
   if (output_value.isTensor ()) {
     g_assert (outputTensorMeta.num_tensors == 1);
index 8428a4e..f69d0cd 100644 (file)
@@ -56,6 +56,7 @@ private:
   GstTensorsInfo inputTensorMeta;  /**< The tensor info of input tensors */
   GstTensorsInfo outputTensorMeta;  /**< The tensor info of output tensors */
   bool configured;
+  bool first_run;           /**< must be reset after setting input info */
 
   std::shared_ptr < torch::jit::script::Module > model;