--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Support.hpp"
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+
+namespace tfkit
+{
+namespace tf
+{
+
+bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+ return node.attr().count(attr_name) > 0;
+}
+
+tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+ assert(HasAttr(node, attr_name));
+ const auto &attr = node.attr().at(attr_name);
+ assert(attr.value_case() == tensorflow::AttrValue::kType);
+ return attr.type();
+}
+
+tensorflow::TensorProto *GetTensorAttr(tensorflow::NodeDef &node, const std::string &attr_name)
+{
+ assert(HasAttr(node, attr_name));
+ tensorflow::AttrValue &attr = node.mutable_attr()->at(attr_name);
+ assert(attr.value_case() == tensorflow::AttrValue::kTensor);
+ return attr.mutable_tensor();
+}
+
+int GetElementCount(const tensorflow::TensorShapeProto &shape)
+{
+ int count = -1;
+
+ for (auto &d : shape.dim())
+ {
+ if (d.size() == 0)
+ {
+ count = 0;
+ break;
+ }
+ if (count == -1)
+ count = 1;
+
+ count *= d.size();
+ }
+ return count;
+}
+
+} // namespace tf
+} // namespace tfkit
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef __SUPPORT_H__
+#define __SUPPORT_H__
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <string>
+
+namespace tfkit
+{
+namespace tf
+{
+
+bool HasAttr(const tensorflow::NodeDef &, const std::string &);
+tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &, const std::string &);
+tensorflow::TensorProto *GetTensorAttr(tensorflow::NodeDef &, const std::string &);
+/// GetElementCount returns -1 for rank-0 tensor shape
+int GetElementCount(const tensorflow::TensorShapeProto &);
+
+} // namespace tf
+} // namespace tfkit
+
+#endif // __SUPPORT_H__