From 5017959155fd2e3d3e31ba234251a08bd4b180b5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 3 Jul 2019 16:29:42 +0900 Subject: [PATCH] [plier/tf] copy_shape() function (#4066) * [plier/tf] copy_shape() function This adds copy_shape() function that copies tensorflow shape proto to angkor shape. Signed-off-by: Hyun Sik Yoon * comment --- contrib/plier-tf/CMakeLists.txt | 1 + contrib/plier-tf/include/plier/tf/Convert.h | 10 ++++++++++ contrib/plier-tf/requires.cmake | 1 + contrib/plier-tf/src/Convert.cpp | 31 +++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+) create mode 100644 contrib/plier-tf/requires.cmake diff --git a/contrib/plier-tf/CMakeLists.txt b/contrib/plier-tf/CMakeLists.txt index 8e9273d..25ddc0b 100644 --- a/contrib/plier-tf/CMakeLists.txt +++ b/contrib/plier-tf/CMakeLists.txt @@ -18,4 +18,5 @@ add_library(plier_tf STATIC ${SOURCES}) set_target_properties(plier_tf PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(plier_tf PUBLIC include) +target_link_libraries(plier_tf angkor) target_link_libraries(plier_tf plier_tf_proto) diff --git a/contrib/plier-tf/include/plier/tf/Convert.h b/contrib/plier-tf/include/plier/tf/Convert.h index d0493cb..797688a 100644 --- a/contrib/plier-tf/include/plier/tf/Convert.h +++ b/contrib/plier-tf/include/plier/tf/Convert.h @@ -18,6 +18,8 @@ #ifndef __PLIER_TF_CONVERT_H__ #define __PLIER_TF_CONVERT_H__ +#include + #include #include @@ -54,6 +56,14 @@ enum class DataLayout const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name); +/** + * @brief Copy shape defined in TensorShapeProto to angkor shape + * + * @note Unknown dimension is not supported + */ +void copy_shape(const tensorflow::TensorShapeProto &tf_shape, + nncc::core::ADT::tensor::Shape &to_shape); + } // namespace tf } // namespace plier diff --git a/contrib/plier-tf/requires.cmake b/contrib/plier-tf/requires.cmake new file mode 100644 index 0000000..654db88 --- /dev/null +++ b/contrib/plier-tf/requires.cmake @@ -0,0 +1 @@ +require("angkor") diff --git a/contrib/plier-tf/src/Convert.cpp b/contrib/plier-tf/src/Convert.cpp index aa3b754..27bfe54 100644 --- a/contrib/plier-tf/src/Convert.cpp +++ b/contrib/plier-tf/src/Convert.cpp @@ -17,6 +17,8 @@ #include +#include + #include #include @@ -118,5 +120,34 @@ const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::str throw std::runtime_error("unknown data layout"); } +void copy_shape(const tensorflow::TensorShapeProto &tf_shape, + nncc::core::ADT::tensor::Shape &to_shape) +{ + assert(!tf_shape.unknown_rank()); + + int64_t tf_rank = tf_shape.dim_size(); + assert(tf_rank < std::numeric_limits::max()); + + int32_t rank = static_cast(tf_rank); + to_shape.resize(rank); + + for (int32_t d = 0; d < rank; d++) + { + int64_t dim_value = tf_shape.dim(d).size(); + assert(dim_value < std::numeric_limits::max()); + + if (dim_value >= 0ULL) + { + uint32_t dim_value32 = static_cast(dim_value); + to_shape.dim(d) = dim_value32; + } + else + { + throw std::runtime_error("Cannot handle unknown dimension"); + // TODO support unknown dimension + } + } +} + } // namespace tf } // namespace plier -- 2.7.4