[moco/ONNX] Introduce Constant operation (#3626)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Thu, 30 May 2019 04:16:53 +0000 (13:16 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 30 May 2019 04:16:53 +0000 (13:16 +0900)
* [moco/ONNX] Introduce Constant operation

This commit will introduce constant operation for moco ONNX frontend
and related tests

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
* add assert

* add version diff description

contrib/moco/lib/frontend/onnx/src/Op/Constant.cpp [new file with mode: 0644]
contrib/moco/lib/frontend/onnx/src/Op/Constant.h [new file with mode: 0644]
contrib/moco/lib/frontend/onnx/src/Op/Constant_V1.cpp [new file with mode: 0644]
contrib/moco/lib/frontend/onnx/src/Op/Constant_V9.cpp [new file with mode: 0644]
contrib/moco/test/onnx/Const_000/test.pbtxt [new file with mode: 0644]

diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Constant.cpp b/contrib/moco/lib/frontend/onnx/src/Op/Constant.cpp
new file mode 100644 (file)
index 0000000..c14d272
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool ConstantGraphBuilder::validate(OpsetVersion opset_version, const ::onnx::NodeProto &node) const
+{
+  if (opset_version >= 9)
+    return Constant_V9().validate(node);
+  else if (opset_version >= 1)
+    return Constant_V1().validate(node);
+  else
+    throw std::runtime_error("Invalid ONNX IR version");
+}
+
+void ConstantGraphBuilder::build(OpsetVersion opset_version, const ::onnx::NodeProto &node,
+                                 GraphBuilderContext *context) const
+{
+  if (opset_version >= 9)
+    Constant_V9().build(node, context);
+  else if (opset_version >= 1)
+    Constant_V1().build(node, context);
+  else
+    throw std::runtime_error("Invalid ONNX IR version");
+}
+
+} // namespace onnx
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Constant, ConstantGraphBuilder)
diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Constant.h b/contrib/moco/lib/frontend/onnx/src/Op/Constant.h
new file mode 100644 (file)
index 0000000..e25441d
--- /dev/null
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+/**
+  * @brief GraphBuilder for Constant(since version 1) node
+  */
+class Constant_V1
+{
+public:
+  bool validate(const ::onnx::NodeProto &) const;
+  void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+/**
+  * @brief GraphBuilder for Constant(since version 9) node
+  * @note Until version 1, only FLOAT16, FLOAT, DOUBLE was supported
+  *       Since version 9, all types are supported
+  */
+class Constant_V9
+{
+public:
+  bool validate(const ::onnx::NodeProto &) const;
+  void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+/**
+  * @brief GraphBuilder for Constant node
+  */
+class ConstantGraphBuilder : public GraphBuilder
+{
+public:
+  bool validate(OpsetVersion, const ::onnx::NodeProto &) const;
+  void build(OpsetVersion, const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+} // namespace onnx
+} // namespace moco
diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Constant_V1.cpp b/contrib/moco/lib/frontend/onnx/src/Op/Constant_V1.cpp
new file mode 100644 (file)
index 0000000..f442f56
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+#include "Convert.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool Constant_V1::validate(const ::onnx::NodeProto &node) const
+{
+  if (node.attribute_size() == 0 || !node.attribute(0).has_t())
+    return false;
+
+  auto type = moco::onnx::tensor_dtype_as_string(node.attribute(0).t().data_type());
+  if (type.compare("FLOAT16") != 0 && type.compare("FLOAT") != 0 && type.compare("DOUBLE") != 0)
+    return false;
+
+  return true;
+}
+
+void Constant_V1::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *nodes = context->nodes();
+
+  // Create a "ConstGen" node for Constant
+  auto const_node = graph->nodes()->create<loco::ConstGen>();
+  auto tensor_attribute = node.attribute().Get(0).t();
+  const_node->dtype(as_loco_datatype(tensor_attribute.data_type()));
+  const_node->rank(tensor_attribute.dims_size());
+  // TODO Support other data types
+  assert(const_node->dtype() == loco::DataType::FLOAT32);
+  const_node->size<loco::DataType::FLOAT32>(tensor_attribute.float_data_size());
+
+  for (uint32_t i = 0; i < const_node->rank(); ++i)
+  {
+    const_node->dim(i) = loco::make_dimension(tensor_attribute.dims(i));
+  }
+
+  // TODO Support other data types
+  for (int i = 0; i < tensor_attribute.float_data_size(); ++i)
+  {
+    const_node->at<loco::DataType::FLOAT32>(i) = tensor_attribute.float_data(i);
+  }
+
+  nodes->enroll(node.name(), const_node);
+  nodes->enroll(node.output(0), const_node);
+}
+
+} // namespace onnx
+} // namespace moco
diff --git a/contrib/moco/lib/frontend/onnx/src/Op/Constant_V9.cpp b/contrib/moco/lib/frontend/onnx/src/Op/Constant_V9.cpp
new file mode 100644 (file)
index 0000000..d2755c9
--- /dev/null
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+#include "Convert.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool Constant_V9::validate(const ::onnx::NodeProto &node) const
+{
+  if (node.attribute_size() == 0 || !node.attribute(0).has_t())
+    return false;
+
+  return true;
+}
+
+void Constant_V9::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *nodes = context->nodes();
+
+  // Create a "ConstGen" node for Constant
+  auto const_node = graph->nodes()->create<loco::ConstGen>();
+  auto tensor_attribute = node.attribute().Get(0).t();
+  const_node->dtype(as_loco_datatype(tensor_attribute.data_type()));
+  const_node->rank(tensor_attribute.dims_size());
+  // TODO Support other data types
+  assert(const_node->dtype() == loco::DataType::FLOAT32);
+  const_node->size<loco::DataType::FLOAT32>(tensor_attribute.float_data_size());
+
+  for (uint32_t i = 0; i < const_node->rank(); ++i)
+  {
+    const_node->dim(i) = loco::make_dimension(tensor_attribute.dims(i));
+  }
+
+  // TODO Support other data types
+  for (int i = 0; i < tensor_attribute.float_data_size(); ++i)
+  {
+    const_node->at<loco::DataType::FLOAT32>(i) = tensor_attribute.float_data(i);
+  }
+
+  nodes->enroll(node.name(), const_node);
+  nodes->enroll(node.output(0), const_node);
+}
+
+} // namespace onnx
+} // namespace moco
diff --git a/contrib/moco/test/onnx/Const_000/test.pbtxt b/contrib/moco/test/onnx/Const_000/test.pbtxt
new file mode 100644 (file)
index 0000000..c5ae298
--- /dev/null
@@ -0,0 +1,52 @@
+# Latest IR_VERSION of 1.4.1 version is 4
+# https://github.com/onnx/onnx/blob/rel-1.4.1/onnx/onnx.proto3
+ir_version: 4
+
+# Opset version of IR_VERSION 4 is 9
+# https://github.com/onnx/onnx/blob/rel-1.4.1/onnx/defs/operator_sets.h
+opset_import {
+  version: 9
+}
+
+graph {
+  name: "Const_000"
+
+  node {
+    name: "const_node"
+    output: "output:0"
+    op_type: "Constant"
+    attribute {
+      name: "const/value"
+      t {
+        dims: 2
+        dims: 3
+        data_type: 1 # FLOAT type
+        float_data: 1.1
+        float_data: 2.2
+        float_data: 3.3
+        float_data: 4.4
+        float_data: 5.5
+        float_data: 6.6
+        name: "const_tensor"
+      }
+      type: TENSOR
+    }
+  }
+
+  output {
+    name: "output:0"
+    type {
+      tensor_type {
+        elem_type: 1 # FLOAT type
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}