/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef MIR_ONNX_ATTRIBUTE_HELPERS_H
18 #define MIR_ONNX_ATTRIBUTE_HELPERS_H
#include "onnx/onnx.pb.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace mir_onnx
{
33 template <typename T> T getAttributeValue(const onnx::AttributeProto &attribute) = delete;
35 template <> inline float getAttributeValue(const onnx::AttributeProto &attribute)
37 assert(attribute.type() == onnx::AttributeProto::FLOAT);
41 template <> inline std::int64_t getAttributeValue(const onnx::AttributeProto &attribute)
43 assert(attribute.type() == onnx::AttributeProto::INT);
47 template <> inline std::string getAttributeValue(const onnx::AttributeProto &attribute)
49 assert(attribute.type() == onnx::AttributeProto::STRING);
53 template <> inline onnx::TensorProto getAttributeValue(const onnx::AttributeProto &attribute)
55 assert(attribute.type() == onnx::AttributeProto::TENSOR);
60 inline std::vector<std::int32_t> getAttributeValue(const onnx::AttributeProto &attribute)
62 assert(attribute.type() == onnx::AttributeProto::INTS);
63 // TODO Check that values fit.
64 return {attribute.ints().cbegin(), attribute.ints().cend()};
68 inline std::vector<std::int64_t> getAttributeValue(const onnx::AttributeProto &attribute)
70 assert(attribute.type() == onnx::AttributeProto::INTS);
71 return {attribute.ints().cbegin(), attribute.ints().cend()};
74 inline const onnx::AttributeProto *findAttribute(const onnx::NodeProto &node,
75 const std::string &name)
77 const auto &attributes = node.attribute();
78 const auto it = std::find_if(
79 attributes.cbegin(), attributes.cend(),
80 [&name](const onnx::AttributeProto &attribute) { return attribute.name() == name; });
81 if (it == attributes.cend())
86 template <typename T> T getAttributeValue(const onnx::NodeProto &node, const std::string &name)
88 const auto *attribute = findAttribute(node, name);
89 if (attribute == nullptr)
90 throw std::runtime_error("Cannot find attribute '" + name + "' in node '" + node.name() + "'.");
91 return getAttributeValue<T>(*attribute);
95 T getAttributeValue(const onnx::NodeProto &node, const std::string &name, T default_value)
97 const auto *attribute = findAttribute(node, name);
98 if (attribute == nullptr)
99 return std::move(default_value);
100 return getAttributeValue<T>(*attribute);
103 } // namespace mir_onnx
105 #endif // MIR_ONNX_ATTRIBUTE_HELPERS_H