[moco/ONNX] Apply opset version to operator and frontend (#3557)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Wed, 22 May 2019 08:19:29 +0000 (17:19 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 22 May 2019 08:19:29 +0000 (17:19 +0900)
* [moco/ONNX] Apply opset version to operator and frontend

This commit will apply opset version to moco ONNX frontend and
ONNX operators.

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
* Repositioning REGISTER_OP_BUILDER

contrib/moco/lib/frontend/onnx/src/Frontend.cpp
contrib/moco/lib/frontend/onnx/src/GraphBuilder.h
contrib/moco/lib/frontend/onnx/src/Op/Identity.cpp
contrib/moco/lib/frontend/onnx/src/Op/Identity.h [new file with mode: 0644]
contrib/moco/lib/frontend/onnx/src/Op/Identity_V1.cpp [new file with mode: 0644]

index cb8cc95..480a7a9 100644 (file)
@@ -124,12 +124,12 @@ void convert_graph(::onnx::ModelProto &onnx_model_proto, loco::Graph *graph)
   {
     if (const auto *graph_builder = moco::onnx::GraphBuilderRegistry::get().lookup(n.op_type()))
     {
-      if (!graph_builder->validate(n))
+      if (!graph_builder->validate(opset_version, n))
       {
         throw std::runtime_error{"Invalid operator: " + n.op_type()};
       }
 
-      graph_builder->build(n, &gb_context);
+      graph_builder->build(opset_version, n, &gb_context);
     }
     else
     {
index 2c0c946..7271eb8 100644 (file)
@@ -28,12 +28,15 @@ namespace onnx
 
 /**
 * @brief Parent class of onnx operation graph builders
+* @note GraphBuilder call proper build and validate function according to opset version
 */
 class GraphBuilder
 {
 public:
-  virtual bool validate(const ::onnx::NodeProto &) const { return true; }
-  virtual void build(const ::onnx::NodeProto &, GraphBuilderContext *) const = 0;
+  using OpsetVersion = int64_t;
+
+  virtual bool validate(OpsetVersion, const ::onnx::NodeProto &) const { return true; }
+  virtual void build(OpsetVersion, const ::onnx::NodeProto &, GraphBuilderContext *) const = 0;
   virtual ~GraphBuilder() {}
 };
 
index 71df7b5..7238ffc 100644 (file)
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "GraphBuilder.h"
+#include "Identity.h"
 
 #include <cassert>
 
@@ -23,18 +23,10 @@ namespace moco
 namespace onnx
 {
 
-/**
-  * @brief GraphBuilder for Identity node
-  */
-class IdentityGraphBuilder : public GraphBuilder
-{
-public:
-  bool validate(const ::onnx::NodeProto &) const override;
-  void build(const ::onnx::NodeProto &, GraphBuilderContext *) const override;
-};
-
+// Deprecated
 bool IdentityGraphBuilder::validate(const ::onnx::NodeProto &node) const { return true; }
 
+// Deprecated
 void IdentityGraphBuilder::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
 {
   assert(context != nullptr);
@@ -56,6 +48,23 @@ void IdentityGraphBuilder::build(const ::onnx::NodeProto &node, GraphBuilderCont
   }
 }
 
+bool IdentityGraphBuilder::validate(OpsetVersion opset_version, const ::onnx::NodeProto &node) const
+{
+  if (opset_version >= 1)
+    return Identity_V1().validate(node);
+  else
+    throw std::runtime_error("Invalid ONNX IR version");
+}
+
+void IdentityGraphBuilder::build(OpsetVersion opset_version, const ::onnx::NodeProto &node,
+                                 GraphBuilderContext *context) const
+{
+  if (opset_version >= 1)
+    Identity_V1().build(node, context);
+  else
+    throw std::runtime_error("Invalid ONNX IR version");
+}
+
 } // namespace onnx
 } // namespace moco
 
diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Identity.h b/contrib/moco/lib/frontend/onnx/src/Op/Identity.h
new file mode 100644 (file)
index 0000000..227574c
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+/**
+  * @brief GraphBuilder for Identity(since version 1) node
+  */
+class Identity_V1
+{
+public:
+  bool validate(const ::onnx::NodeProto &) const;
+  void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+/**
+  * @brief GraphBuilder for Identity node
+  */
+class IdentityGraphBuilder : public GraphBuilder
+{
+public:
+  // Deprecated
+  bool validate(const ::onnx::NodeProto &) const;
+  // Deprecated
+  void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+
+  bool validate(OpsetVersion, const ::onnx::NodeProto &) const;
+  void build(OpsetVersion, const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+} // namespace onnx
+} // namespace moco
diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Identity_V1.cpp b/contrib/moco/lib/frontend/onnx/src/Op/Identity_V1.cpp
new file mode 100644 (file)
index 0000000..6ae6558
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Identity.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool Identity_V1::validate(const ::onnx::NodeProto &) const { return true; }
+
+void Identity_V1::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *nodes = context->nodes();
+  SymbolTable *input_names = context->input_names();
+
+  // Create a "Forward" node for Identity
+  auto forward_node = graph->nodes()->create<loco::Forward>();
+
+  nodes->enroll(node.name(), forward_node);
+  nodes->enroll(node.output(0), forward_node);
+
+  // Record all inputs to forward_node
+  for (int i = 0; i < node.input_size(); ++i)
+  {
+    const auto &input_name = node.input(i);
+    input_names->list(forward_node, input_name);
+  }
+}
+
+} // namespace onnx
+} // namespace moco