--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool ConstantGraphBuilder::validate(OpsetVersion opset_version, const ::onnx::NodeProto &node) const
+{
+ if (opset_version >= 9)
+ return Constant_V9().validate(node);
+ else if (opset_version >= 1)
+ return Constant_V1().validate(node);
+ else
+ throw std::runtime_error("Invalid ONNX IR version");
+}
+
+void ConstantGraphBuilder::build(OpsetVersion opset_version, const ::onnx::NodeProto &node,
+ GraphBuilderContext *context) const
+{
+ if (opset_version >= 9)
+ Constant_V9().build(node, context);
+ else if (opset_version >= 1)
+ Constant_V1().build(node, context);
+ else
+ throw std::runtime_error("Invalid ONNX IR version");
+}
+
+} // namespace onnx
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Constant, ConstantGraphBuilder)
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+/**
+ * @brief GraphBuilder for Constant(since version 1) node
+ */
+class Constant_V1
+{
+public:
+ bool validate(const ::onnx::NodeProto &) const;
+ void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+/**
+ * @brief GraphBuilder for Constant(since version 9) node
+ * @note Until version 1, only FLOAT16, FLOAT, DOUBLE was supported
+ * Since version 9, all types are supported
+ */
+class Constant_V9
+{
+public:
+ bool validate(const ::onnx::NodeProto &) const;
+ void build(const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+/**
+ * @brief GraphBuilder for Constant node
+ */
+class ConstantGraphBuilder : public GraphBuilder
+{
+public:
+ bool validate(OpsetVersion, const ::onnx::NodeProto &) const;
+ void build(OpsetVersion, const ::onnx::NodeProto &, GraphBuilderContext *) const;
+};
+
+} // namespace onnx
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+#include "Convert.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool Constant_V1::validate(const ::onnx::NodeProto &node) const
+{
+ if (node.attribute_size() == 0 || !node.attribute(0).has_t())
+ return false;
+
+ auto type = moco::onnx::tensor_dtype_as_string(node.attribute(0).t().data_type());
+ if (type.compare("FLOAT16") != 0 && type.compare("FLOAT") != 0 && type.compare("DOUBLE") != 0)
+ return false;
+
+ return true;
+}
+
+void Constant_V1::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *nodes = context->nodes();
+
+ // Create a "ConstGen" node for Constant
+ auto const_node = graph->nodes()->create<loco::ConstGen>();
+ auto tensor_attribute = node.attribute().Get(0).t();
+ const_node->dtype(as_loco_datatype(tensor_attribute.data_type()));
+ const_node->rank(tensor_attribute.dims_size());
+ // TODO Support other data types
+ assert(const_node->dtype() == loco::DataType::FLOAT32);
+ const_node->size<loco::DataType::FLOAT32>(tensor_attribute.float_data_size());
+
+ for (uint32_t i = 0; i < const_node->rank(); ++i)
+ {
+ const_node->dim(i) = loco::make_dimension(tensor_attribute.dims(i));
+ }
+
+ // TODO Support other data types
+ for (int i = 0; i < tensor_attribute.float_data_size(); ++i)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = tensor_attribute.float_data(i);
+ }
+
+ nodes->enroll(node.name(), const_node);
+ nodes->enroll(node.output(0), const_node);
+}
+
+} // namespace onnx
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Constant.h"
+#include "Convert.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+bool Constant_V9::validate(const ::onnx::NodeProto &node) const
+{
+ if (node.attribute_size() == 0 || !node.attribute(0).has_t())
+ return false;
+
+ return true;
+}
+
+void Constant_V9::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *nodes = context->nodes();
+
+ // Create a "ConstGen" node for Constant
+ auto const_node = graph->nodes()->create<loco::ConstGen>();
+ auto tensor_attribute = node.attribute().Get(0).t();
+ const_node->dtype(as_loco_datatype(tensor_attribute.data_type()));
+ const_node->rank(tensor_attribute.dims_size());
+ // TODO Support other data types
+ assert(const_node->dtype() == loco::DataType::FLOAT32);
+ const_node->size<loco::DataType::FLOAT32>(tensor_attribute.float_data_size());
+
+ for (uint32_t i = 0; i < const_node->rank(); ++i)
+ {
+ const_node->dim(i) = loco::make_dimension(tensor_attribute.dims(i));
+ }
+
+ // TODO Support other data types
+ for (int i = 0; i < tensor_attribute.float_data_size(); ++i)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = tensor_attribute.float_data(i);
+ }
+
+ nodes->enroll(node.name(), const_node);
+ nodes->enroll(node.output(0), const_node);
+}
+
+} // namespace onnx
+} // namespace moco
--- /dev/null
+# Latest IR_VERSION of 1.4.1 version is 4
+# https://github.com/onnx/onnx/blob/rel-1.4.1/onnx/onnx.proto3
+ir_version: 4
+
+# Opset version of IR_VERSION 4 is 9
+# https://github.com/onnx/onnx/blob/rel-1.4.1/onnx/defs/operator_sets.h
+opset_import {
+ version: 9
+}
+
+graph {
+ name: "Const_000"
+
+ node {
+ name: "const_node"
+ output: "output:0"
+ op_type: "Constant"
+ attribute {
+ name: "const/value"
+ t {
+ dims: 2
+ dims: 3
+ data_type: 1 # FLOAT type
+ float_data: 1.1
+ float_data: 2.2
+ float_data: 3.3
+ float_data: 4.4
+ float_data: 5.5
+ float_data: 6.6
+ name: "const_tensor"
+ }
+ type: TENSOR
+ }
+ }
+
+ output {
+ name: "output:0"
+ type {
+ tensor_type {
+ elem_type: 1 # FLOAT type
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}