[refactor] Restructure getStringDataType function
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Fri, 2 Aug 2024 04:02:23 +0000 (13:02 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 5 Aug 2024 01:01:43 +0000 (10:01 +0900)
This patch updates the getStringDataType function structure to utilize method overriding.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
nntrainer/tensor/float_tensor.h
nntrainer/tensor/half_tensor.h
nntrainer/tensor/tensor_base.h

index dd976d91d9256e84077b45ae1ce7fcff70b62ce4..017433e7c9feed5d0757eb4dbc79aad8489deb1c 100644 (file)
@@ -500,6 +500,12 @@ private:
                                           const float *, float *)>
                          v_func,
                        Tensor &output) const;
+
+  /**
+   * @brief  Get the Data Type String object
+   * @return std::string of tensor data type (FP32)
+   */
+  std::string getStringDataType() const override { return "FP32"; }
 };
 
 } // namespace nntrainer
index e0dfd77748e3e959bdecfcf9dc5771f021a80aee..8db09c0cce00db89e40e58a201d4945846a47617 100644 (file)
@@ -491,6 +491,12 @@ private:
                                           const _FP16 *, _FP16 *)>
                          v_func,
                        Tensor &output) const;
+
+  /**
+   * @brief  Get the Data Type String object
+   * @return std::string of tensor data type (FP16)
+   */
+  std::string getStringDataType() const override { return "FP16"; }
 };
 
 } // namespace nntrainer
index 576ed5db1f9a176c9ebe0107168658fbdb77a7c7..8caaeadd340a84765a29b1ee05febf5087ae5726 100644 (file)
@@ -735,28 +735,10 @@ protected:
   /**
    * @brief  Get the Data Type String object
    * @return std::string of tensor data type
+   * @note   TensorBase::getStringDataType() should not be called. Please define
+   * this function in the derived class to the corresponding data type.
    */
-  std::string getStringDataType() const {
-    std::string res;
-    switch (getDataType()) {
-    case Tdatatype::FP32:
-      res = "FP32";
-      break;
-    case Tdatatype::FP16:
-      res = "FP16";
-      break;
-    case Tdatatype::QINT8:
-      res = "QINT8";
-      break;
-    case Tdatatype::QINT4:
-      res = "QINT4";
-      break;
-    default:
-      res = "Undefined type";
-      break;
-    }
-    return res;
-  }
+  virtual std::string getStringDataType() const { return "Undefined type"; }
 };
 
 /**