Added default constructor for RNNCellBase, fix conversions (#1370)
authorIvan Tikhonov <ivan.tikhonov@intel.com>
Mon, 20 Jul 2020 11:15:37 +0000 (14:15 +0300)
committerGitHub <noreply@github.com>
Mon, 20 Jul 2020 11:15:37 +0000 (14:15 +0300)
inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
inference-engine/src/readers/ir_reader/ie_ir_parser.hpp
ngraph/src/ngraph/op/util/rnn_cell_base.cpp
ngraph/src/ngraph/op/util/rnn_cell_base.hpp

index d334c8a..cd7883a 100644 (file)
@@ -91,6 +91,24 @@ public:
         params[name] = std::to_string(adapter.get());
     }
 
+    void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
+        std::vector<std::string> data = adapter.get();
+        for (auto& str : data) {
+            std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
+                return std::tolower(c);
+            });
+        }
+
+        std::stringstream ss;
+        std::copy(data.begin(), data.end(), std::ostream_iterator<std::string>(ss, ","));
+        params[name] = ss.str();
+    }
+
+    void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<float>>& adapter) override {
+        auto data = adapter.get();
+        params[name] = joinVec(data);
+    }
+
     void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
 
 private:
@@ -118,6 +136,9 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
     } else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) {
         auto shape = static_cast<::ngraph::Strides&>(*a);
         params[name] = joinVec(shape);
+    } else {
+        THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
+                              "Attribute adapter can not be found for " << name << " parameter";
     }
 }
 
index 2f7513c..862622d 100644 (file)
@@ -218,6 +218,9 @@ private:
             } else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
                 if (!getStrAttribute(node.child("data"), name, val)) return;
                 static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
+            }  else {
+                THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
+                                   << " parameter";
             }
         }
         void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
index d931e58..276187f 100644 (file)
@@ -37,6 +37,12 @@ static vector<string> to_lower_case(const vector<string>& vs)
     return res;
 }
 
+op::util::RNNCellBase::RNNCellBase()
+    : m_clip(0.f)
+    , m_hidden_size(0)
+{
+}
+
 op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
                                    float clip,
                                    const vector<string>& activations,
index ad02dfb..9034ddf 100644 (file)
@@ -56,7 +56,7 @@ namespace ngraph
                             const std::vector<float>& activations_alpha,
                             const std::vector<float>& activations_beta);
 
-                RNNCellBase() = default;
+                RNNCellBase();
                 virtual ~RNNCellBase() = default;
 
                 virtual bool visit_attributes(AttributeVisitor& visitor);