[pytorch] Update pytorch extension for pytorch v1.6.0-rc1
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 16 Jul 2020 09:54:03 +0000 (18:54 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 23 Jul 2020 00:41:13 +0000 (09:41 +0900)
Update pytorch extension for pytorch v1.6.0-rc1
Also pytorch and caffe2 headers do not follow redundant declarations
So, remove redundant declaration warning flag when compiling with them

V2:
Check version of pytorch from pkg-config
Update extension based on version number due to api update

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
ext/nnstreamer/tensor_filter/meson.build
ext/nnstreamer/tensor_filter/tensor_filter_pytorch.cc
meson.build

index 911a46f..1cf4e80 100644 (file)
@@ -144,6 +144,20 @@ if pytorch_support_is_available
 
   nnstreamer_filter_torch_deps = pytorch_support_deps + [glib_dep, gst_dep, nnstreamer_dep]
 
+  pytorch_compile_args = []
+
+  torch_ver_dep = dependency('pytorch', version : '>=1.2.0', required : false)
+  if torch_ver_dep.found()
+
+    pytorch_compile_args += '-DPYTORCH_VER_ATLEAST_1_2_0=1'
+  endif
+
+  pytorch_extra_dep = declare_dependency(
+    compile_args : pytorch_compile_args,
+  )
+
+  nnstreamer_filter_torch_deps += pytorch_extra_dep
+
   shared_library('nnstreamer_filter_pytorch',
     nnstreamer_filter_torch_sources,
     dependencies: nnstreamer_filter_torch_deps,
index 5b5f299..802687f 100644 (file)
@@ -77,7 +77,7 @@ private:
   bool configured;
   bool first_run;           /**< must be reset after setting input info */
 
-  std::shared_ptr < torch::jit::script::Module > model;
+  std::shared_ptr<torch::jit::script::Module> model;
 
   void setAccelerator (const char *accelerators);
   tensor_type getTensorTypeFromTorch (torch::Dtype torchType);
@@ -197,7 +197,12 @@ TorchCore::loadModel ()
     return -1;
   }
 
+#ifdef PYTORCH_VER_ATLEAST_1_2_0
+  model = std::make_shared<torch::jit::script::Module>(torch::jit::load (model_path));
+#else
   model = torch::jit::load (model_path);
+#endif
+
   if (model == nullptr) {
     ml_loge ("Failed to read graph.");
     return -2;
@@ -477,12 +482,20 @@ TorchCore::invoke (const GstTensorMemory * input, GstTensorMemory * output)
       ml_loge ("Output Tensor Information is not valid");
       return -2;
     }
+#ifdef PYTORCH_VER_ATLEAST_1_2_0
+  } else if (output_value.isList ()) {
+    c10::ArrayRef<torch::jit::IValue> output_ref_list =
+      output_value.toListRef ();
+    std::vector<torch::jit::IValue> output_list (
+        output_ref_list.begin (), output_ref_list.end ());
+#else
   } else if (output_value.isGenericList ()) {
-    std::vector < torch::jit::IValue > output_list =
-        output_value.toGenericListRef ();
+    c10::ArrayRef<torch::jit::IValue> output_list =
+      output_value.toGenericListRef ();
+#endif
     g_assert (outputTensorMeta.num_tensors == output_list.size ());
     int idx = 0;
-  for (auto & ivalue_element:output_list) {
+    for (auto & ivalue_element:output_list) {
       if (processIValue (ivalue_element, &output[idx++])) {
         ml_loge ("Output Tensor Information is not valid");
         return -2;
index 862c272..e983de7 100644 (file)
@@ -41,8 +41,8 @@ elif not meson.is_cross_build()
   endif
 endif
 
+# Define warning flags for c and cpp
 warning_flags = [
-  '-Wredundant-decls',
   '-Wwrite-strings',
   '-Wformat',
   '-Wformat-nonliteral',
@@ -64,6 +64,7 @@ warning_c_flags = [
   '-Wdeclaration-after-statement'
 ]
 
+# Setup warning flags for c anc cpp
 foreach extra_arg : warning_flags
   if cc.has_argument (extra_arg)
     add_project_arguments([extra_arg], language: 'c')
@@ -288,7 +289,6 @@ foreach feature_name, data : features
 
 endforeach
 
-
 #Definitions enabled by meson_options.txt
 message('Following project_args are going to be included')
 message(project_args)
@@ -296,6 +296,17 @@ foreach name, value: project_args
   add_project_arguments('-D@0@=@1@'.format(name, value), language: ['c', 'cpp'])
 endforeach
 
+# Add redundant declaration flag when caffe2 and pytorch both are disabled
+if not (pytorch_support_is_available and caffe2_support_is_available)
+  redundant_decls_flag = '-Wredundant-decls'
+  if cc.has_argument (redundant_decls_flag)
+    add_project_arguments([redundant_decls_flag], language: 'c')
+  endif
+  if cxx.has_argument (redundant_decls_flag)
+    add_project_arguments([redundant_decls_flag], language: 'cpp')
+  endif
+endif
+
 # Python
 have_python2 = false
 have_python3 = false