[tfkit] support methods for reading tf attributes (#3122)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 27 Mar 2019 04:41:17 +0000 (13:41 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 27 Mar 2019 04:41:17 +0000 (13:41 +0900)
* [tfkit] support methods for reading tf attributes

This will add some methods to read tensorflow model attributes

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* fix for comments

contrib/tfkit/src/Support.cpp [new file with mode: 0644]
contrib/tfkit/src/Support.hpp [new file with mode: 0644]

diff --git a/contrib/tfkit/src/Support.cpp b/contrib/tfkit/src/Support.cpp
new file mode 100644 (file)
index 0000000..70ef2ee
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Support.hpp"
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+
+namespace tfkit
+{
+namespace tf
+{
+
+bool HasAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  return node.attr().count(attr_name) > 0;
+}
+
+tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  assert(HasAttr(node, attr_name));
+  const auto &attr = node.attr().at(attr_name);
+  assert(attr.value_case() == tensorflow::AttrValue::kType);
+  return attr.type();
+}
+
+tensorflow::TensorProto *GetTensorAttr(tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  assert(HasAttr(node, attr_name));
+  tensorflow::AttrValue &attr = node.mutable_attr()->at(attr_name);
+  assert(attr.value_case() == tensorflow::AttrValue::kTensor);
+  return attr.mutable_tensor();
+}
+
+int GetElementCount(const tensorflow::TensorShapeProto &shape)
+{
+  int count = -1;
+
+  for (auto &d : shape.dim())
+  {
+    if (d.size() == 0)
+    {
+      count = 0;
+      break;
+    }
+    if (count == -1)
+      count = 1;
+
+    count *= d.size();
+  }
+  return count;
+}
+
+} // namespace tf
+} // namespace tfkit
diff --git a/contrib/tfkit/src/Support.hpp b/contrib/tfkit/src/Support.hpp
new file mode 100644 (file)
index 0000000..d978713
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef __SUPPORT_H__
+#define __SUPPORT_H__
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <string>
+
+namespace tfkit
+{
+namespace tf
+{
+
+bool HasAttr(const tensorflow::NodeDef &, const std::string &);
+tensorflow::DataType GetDataTypeAttr(const tensorflow::NodeDef &, const std::string &);
+tensorflow::TensorProto *GetTensorAttr(tensorflow::NodeDef &, const std::string &);
+/// GetElementCount returns -1 for rank-0 tensor shape
+int GetElementCount(const tensorflow::TensorShapeProto &);
+
+} // namespace tf
+} // namespace tfkit
+
+#endif // __SUPPORT_H__