From: Donghyeon Jeong Date: Fri, 2 Aug 2024 04:02:23 +0000 (+0900) Subject: [refactor] Restructure getStringDataType function X-Git-Tag: accepted/tizen/7.0/unified/20240830.164841~29 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=32d901cee3c913a27e95ce81923266f6de7a063b;p=platform%2Fcore%2Fml%2Fnntrainer.git [refactor] Restructure getStringDataType function This patch updates the getStringDataType function structure to utilize method overriding. **Self-evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Donghyeon Jeong --- diff --git a/nntrainer/tensor/float_tensor.h b/nntrainer/tensor/float_tensor.h index dd976d91..017433e7 100644 --- a/nntrainer/tensor/float_tensor.h +++ b/nntrainer/tensor/float_tensor.h @@ -500,6 +500,12 @@ private: const float *, float *)> v_func, Tensor &output) const; + + /** + * @brief Get the Data Type String object + * @return std::string of tensor data type (FP32) + */ + std::string getStringDataType() const override { return "FP32"; } }; } // namespace nntrainer diff --git a/nntrainer/tensor/half_tensor.h b/nntrainer/tensor/half_tensor.h index e0dfd777..8db09c0c 100644 --- a/nntrainer/tensor/half_tensor.h +++ b/nntrainer/tensor/half_tensor.h @@ -491,6 +491,12 @@ private: const _FP16 *, _FP16 *)> v_func, Tensor &output) const; + + /** + * @brief Get the Data Type String object + * @return std::string of tensor data type (FP16) + */ + std::string getStringDataType() const override { return "FP16"; } }; } // namespace nntrainer diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h index 576ed5db..8caaeadd 100644 --- a/nntrainer/tensor/tensor_base.h +++ b/nntrainer/tensor/tensor_base.h @@ -735,28 +735,10 @@ protected: /** * @brief Get the Data Type String object * @return std::string of tensor data type + * @note TensorBase::getStringDataType() should not be called. Please define + * this function in the derived class to the corresponding data type. */ - std::string getStringDataType() const { - std::string res; - switch (getDataType()) { - case Tdatatype::FP32: - res = "FP32"; - break; - case Tdatatype::FP16: - res = "FP16"; - break; - case Tdatatype::QINT8: - res = "QINT8"; - break; - case Tdatatype::QINT4: - res = "QINT4"; - break; - default: - res = "Undefined type"; - break; - } - return res; - } + virtual std::string getStringDataType() const { return "Undefined type"; } }; /**