From 3c2b8878199851f829dffb53eeed397096bb8564 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Tue, 6 Oct 2020 12:34:47 +0900 Subject: [PATCH] [Tensor/model] Move some member functions to private Move BroadcastInfo to private Move printMetrics to private **Self evaluation:** 1. Build test: [x]Passed [ ]Failed [ ]Skipped 2. Run test: [x]Passed [ ]Failed [ ]Skipped Signed-off-by: Parichay Kapoor --- nntrainer/include/neuralnet.h | 14 +++++++------- nntrainer/include/tensor.h | 24 ++---------------------- nntrainer/src/tensor.cpp | 24 +++++++++++++++++++++++- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/nntrainer/include/neuralnet.h b/nntrainer/include/neuralnet.h index 625146d..16b5db7 100644 --- a/nntrainer/include/neuralnet.h +++ b/nntrainer/include/neuralnet.h @@ -279,13 +279,6 @@ public: */ void printPreset(std::ostream &out, unsigned int preset); - /** - * @brief print metrics function for neuralnet - * @param[in] out outstream - * @param[in] flags verbosity from ml_train_summary_type_e - */ - void printMetrics(std::ostream &out, unsigned int flags = 0); - private: /** * @brief Print Options when printing layer info @@ -447,6 +440,13 @@ private: * @brief Update batch size of the model as well as its layers/dataset */ void setBatchSize(unsigned int batch_size); + + /** + * @brief print metrics function for neuralnet + * @param[in] out outstream + * @param[in] flags verbosity from ml_train_summary_type_e + */ + void printMetrics(std::ostream &out, unsigned int flags = 0); }; } /* namespace nntrainer */ diff --git a/nntrainer/include/tensor.h b/nntrainer/include/tensor.h index 260df2d..cb9993c 100644 --- a/nntrainer/include/tensor.h +++ b/nntrainer/include/tensor.h @@ -39,28 +39,6 @@ namespace nntrainer { class LazyTensor; /** - * @struct External Loop Info for broadcasted info - * @brief External Loop Info for broadcasted iteration. Please refer to - * DISABLED_private_external_loop_n in unittest_nntrainer_tensor. - * @note This should better be implemented in iterator fashion before used - * extensively. - */ -struct BroadcastInfo { - - /** - * @brief Construct a new External Loop Info object - * - */ - BroadcastInfo() : strides{0, 0, 0, 0} {} - - unsigned int buffer_size; /**< virtual size of the buffer */ - int buffer_axis; /**< the smallest axis that should be looped. - -1 means no loop needed*/ - std::array - strides; /**< modified strides for the loop */ -}; - -/** * @class Tensor Class for Calculation * @brief Tensor Class for Calculation */ @@ -593,6 +571,8 @@ private: return (b * strides[0] + c * strides[1] + h * strides[2] + w * strides[3]); } + struct BroadcastInfo; + /** * @brief Applies the given operator to the tensor with the passed argument * @param[in] m Tensor diff --git a/nntrainer/src/tensor.cpp b/nntrainer/src/tensor.cpp index 16c1378..ff0de42 100644 --- a/nntrainer/src/tensor.cpp +++ b/nntrainer/src/tensor.cpp @@ -208,6 +208,28 @@ int Tensor::add_i(float const &value) { Tensor Tensor::add(float const &value) { CLONE_OP_I(add_i, value); } /** + * @struct External Loop Info for broadcasted info + * @brief External Loop Info for broadcasted iteration. Please refer to + * DISABLED_private_external_loop_n in unittest_nntrainer_tensor. + * @note This should better be implemented in iterator fashion before used + * extensively. + */ +struct Tensor::BroadcastInfo { + + /** + * @brief Construct a new External Loop Info object + * + */ + BroadcastInfo() : strides{0, 0, 0, 0} {} + + unsigned int buffer_size; /**< virtual size of the buffer */ + int buffer_axis; /**< the smallest axis that should be looped. + -1 means no loop needed*/ + std::array + strides; /**< modified strides for the loop */ +}; + +/** * @brief Add Tensor Element by Element without mem copy * @param[in] m Tensor to be added * #retval #ML_ERROR_NONE Successful @@ -798,7 +820,7 @@ Tensor Tensor::standardization() const { return result; } -BroadcastInfo Tensor::computeBroadcastInfo(const Tensor &m) { +Tensor::BroadcastInfo Tensor::computeBroadcastInfo(const Tensor &m) { if (m.length() > this->length()) throw exception::not_supported("broadcasting *this is not supported"); -- 2.7.4