[ Mixed ] fix apply using casted function
authorjijoong.moon <jijoong.moon@samsung.com>
Fri, 28 Jul 2023 10:49:52 +0000 (19:49 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
Describe a commit content (Until 80 colums per line) in detail ASAP.

**Changes proposed in this PR:**
- Added TOC generator for README.md

Resolves:

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
17 files changed:
api/ccapi/include/tensor_dim.h
meson.build
nntrainer/layers/embedding.cpp
nntrainer/layers/input_layer.cpp
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/loss/loss_layer.cpp
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_interface.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/tensor/tensor_dim.cpp
nntrainer/utils/util_func.cpp
test/nntrainer_test_util.cpp
test/unittest/layers/unittest_layers_fully_connected.cpp
test/unittest/unittest_nntrainer_activations.cpp
test/unittest/unittest_nntrainer_tensor_fp16.cpp

index e6cf58265f2eefc7108cc418077b930d1f8e76df..dd2c11286938644246ab10a1d3738a1393401d33 100644 (file)
 #include <bitset>
 #include <vector>
 
+#ifdef USE__FP16
+#define  _FP16 __fp16
+#else
+#define  _FP16 _Float16
+#endif
+
 namespace ml {
 namespace train {
 
index 85592472704640db0826f8bdfc326d27d118a6ba..d4fb43fde7b0874f52d24e12d7c5d287a2f65bb6 100644 (file)
@@ -73,9 +73,23 @@ warning_c_flags = [
 # enfif
 
 if get_option('enable-fp16')
-  extra_defines += '-DENABLE_FP16=1'
+   arch = target_machine.cpu_family()
+   extra_defines += '-DENABLE_FP16=1'
+
+   if get_option('platform') == 'android'
+     add_project_arguments('-mfp16-format=ieee', language: ['c', 'cpp'])
+     extra_defines += '-DUSE__FP16=1'
+   else
+     has_avx512fp16 = cc.has_argument('-mavx512fp16')
+     if (has_avx512fp16)
+       # add_project_arguments(['-mavx512fp16'], language: ['c','cpp'])
+       message ('Float16 for x86_64 enabled. Modern gcc-x64 genrally supports float16 with _Float16. -mavx512fp16 added for hardware acceleration')
+     else
+       warning ('Float16 for x86_64 enabled. However, software emulation is applied for fp16, making it slower and inconsistent. Use GCC 12+ for AVX512 FP16 support. This build will probably fail unless you bring a compiler that supports fp16 for x64.')
+     endif
+   endif  
 endif
-
+    
 foreach extra_arg : warning_flags
   if cc.has_argument (extra_arg)
     add_project_arguments([extra_arg], language: 'c')
index 3fbd1372f186cd5605bf504795bfa5465d8fba1d..46136e02a2116112b783c4237f3cb2b4f6aa5797 100644 (file)
@@ -34,6 +34,8 @@ void EmbeddingLayer::finalize(InitLayerContext &context) {
   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
     << "Embedding layer takes only one input";
 
+  context.setInputDataType(TensorDim::DataType::FP32);
+  
   const TensorDim &input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
   NNTR_THROW_IF(input_dim.channel() != 1, std::invalid_argument)
     << "Embedding layer takes only one for channel size";
@@ -53,9 +55,13 @@ void EmbeddingLayer::finalize(InitLayerContext &context) {
 
   output_dim.height(input_dim.width());
   output_dim.width(out_dim);
+  output_dim.setTensorType({context.getFormat(), context.getActivationDataType()});
   context.setOutputDimensions({output_dim});
 
   TensorDim dim = output_dim;
+
+  dim.setTensorType({context.getFormat(), context.getWeightDataType()});
+
   dim.height(in_dim);
   dim.width(out_dim);
   dim.batch(1);
index 5417e44f1cbd77bd23a48e8b31f6c7482608acf6..70a056b185d649fe0f4bfd0863fc4d5102ea918a 100644 (file)
@@ -67,7 +67,13 @@ void InputLayer::exportTo(Exporter &exporter,
 }
 
 void InputLayer::finalize(InitLayerContext &context) {
-  context.setOutputDimensions(context.getInputDimensions());
+
+  std::vector<TensorDim> output_dims = context.getInputDimensions();
+
+  for (auto d : output_dims)
+    d.setTensorType({context.getFormat(), context.getActivationDataType()});
+
+  context.setOutputDimensions(output_dims);
 }
 
 } /* namespace nntrainer */
index 6d55d66b1cd47b063ab3dd3939d316e329dedc82..7ec07aa693c0485c9e807c6ccbd0544b895c2be6 100644 (file)
@@ -522,9 +522,9 @@ bool RunLayerContext::validate(bool skip_input, bool skip_label) {
         } else if (val->getVariableRef().getTensorType().data_type ==
                    TensorDim::DataType::FP32) {
           tensor_map[val->getName()] =
-            val->getVariableRef().getData<_Float16>();
+            val->getVariableRef().getData<_FP16>();
           tensor_map[val->getGradientName()] =
-            val->getGradientRef().getData<_Float16>();
+            val->getGradientRef().getData<_FP16>();
         }
       }
     };
index 738b385657e5a605ef2cac3ef83b8c6f03299a98..f268489b0d14a86e1f89086df390a33503cbb1f6 100644 (file)
@@ -112,6 +112,11 @@ public:
    */
   const std::vector<TensorDim> &getInputDimensions() const { return input_dim; }
 
+  void setInputDataType(TensorDim::DataType ty) {
+    for (auto d : input_dim)
+      d.setDataType(ty);
+  }
+
   /**
    * @brief Set the Dim Flag to retrieve effective dimension
    *
index 84f767f4ba09341df72d476fe2abbc882fa29e32..f17b0b41a92bb8cebea6346a6e60f74b9aaa929f 100644 (file)
 
 namespace nntrainer {
 void LossLayer::finalize(InitLayerContext &context) {
-  context.setOutputDimensions(context.getInputDimensions());
+  std::vector<TensorDim> input_dim = context.getInputDimensions();
+  std::vector<TensorDim> output_dim = input_dim;
+  for (auto d : output_dim)
+    d.setDataType(
+      str_converter<enum_class_prop_tag,
+                    nntrainer::TensorDataTypeInfo>::from_string("FP32"));
+  
+  context.setOutputDimensions(output_dim);
 }
 
 void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) {
index 418336caebc05c267d327141a1604baea6ac6aaf..df89646e965a8b6fe696e51d33f470f25debb155 100644 (file)
 
 #define sgemv_loop_fp16(ci, cj, cM, cN)           \
   do {                                       \
-    _Float16 y0;                               \
+    _FP16 y0;                               \
     unsigned int i, j;                       \
     for (ci = 0; ci != cM; ci++) {           \
-      y0 = Y[ci * incy] * static_cast<_Float16>(beta);              \
+      y0 = Y[ci * incy] * static_cast<_FP16>(beta);              \
       for (cj = 0; cj != cN; cj++)           \
         y0 += A[i + j * lda] * X[cj * incx]; \
       Y[ci * incy] = y0;                     \
 namespace nntrainer {
 
 #ifdef ENABLE_FP16
-static void saxpy_FP16(const unsigned int N, const float alpha, const _Float16 *X,
-                       const int incX, _Float16 *Y, const int incY) {
+static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
+                       const int incX, _FP16 *Y, const int incY) {
   if (incX < 0 or incY < 0)
     throw std::invalid_argument(
       "Error: negative inc not supported without cblas");
   for (unsigned int i = 0; i < N; ++i)
-    Y[i * incY] = Y[i * incY] + static_cast<_Float16>(alpha) * X[i * incX];
+    Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX];
 }
 
 static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                        const unsigned int M, const unsigned int N,
-                       const float alpha, const _Float16 *A,
-                       const unsigned int lda, const _Float16 *X, const int incX,
-                       const float beta, _Float16 *Y, const int incY) {
+                       const float alpha, const _FP16 *A,
+                       const unsigned int lda, const _FP16 *X, const int incX,
+                       const float beta, _FP16 *Y, const int incY) {
 
   unsigned int incy = abs(incY);
   unsigned int incx = abs(incX);
@@ -69,18 +69,18 @@ static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   }
 }
 
-static _Float16 sdot_FP16(const unsigned int N, const _Float16 *X,
-                        const unsigned int incX, const _Float16 *Y,
+static _FP16 sdot_FP16(const unsigned int N, const _FP16 *X,
+                        const unsigned int incX, const _FP16 *Y,
                         const unsigned int incY) {
-  _Float16 ret = 0;
+  _FP16 ret = 0;
   for (unsigned int i = 0; i < N; ++i) {
     ret += X[i * incX] * Y[i * incY];
   }
   return ret;
 }
 
-static void scopy_FP16(const unsigned int N, const _Float16 *X, const int incX,
-                       _Float16 *Y, const int incY) {
+static void scopy_FP16(const unsigned int N, const _FP16 *X, const int incX,
+                       _FP16 *Y, const int incY) {
   unsigned int incy = abs(incY);
   unsigned int incx = abs(incX);
 
@@ -88,56 +88,56 @@ static void scopy_FP16(const unsigned int N, const _Float16 *X, const int incX,
     Y[i * incy] = X[i * incx];
 }
 
-void sscal(const unsigned int N, const float alpha, _Float16 *X, const int incX) {
+void sscal(const unsigned int N, const float alpha, _FP16 *X, const int incX) {
   unsigned int incx = abs(incX);
 
   for (unsigned int i = 0; i < N; ++i)
-    X[i * incx] = static_cast<_Float16>(alpha) * X[i * incx];
+    X[i * incx] = static_cast<_FP16>(alpha) * X[i * incx];
 }
 
-static _Float16 snrm2_FP16(const unsigned int N, const _Float16 *X, const int incX) {
+static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
   unsigned int incx = abs(incX);
-  _Float16 sum = 0;
-  _Float16 tmp;
+  _FP16 sum = 0;
+  _FP16 tmp;
 #pragma omp parallel for private(tmp) reduction(+ : sum)
   for (unsigned int i = 0; i < N; i++) {
     tmp = X[i * incx];
     sum += tmp * tmp;
   }
-  return static_cast<_Float16>(sqrt(sum));
+  return static_cast<_FP16>(sqrt(sum));
 }
 static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                        CBLAS_TRANSPOSE TransB, const unsigned int M,
                        const unsigned int N, const unsigned int K,
-                       const float alpha, const _Float16 *A,
-                       const unsigned int lda, const _Float16 *B,
-                       const unsigned int ldb, const float beta, _Float16 *C,
+                       const float alpha, const _FP16 *A,
+                       const unsigned int lda, const _FP16 *B,
+                       const unsigned int ldb, const float beta, _FP16 *C,
                        const unsigned int ldc) {
 
   for (unsigned int m = 0; m < M; ++m) {
     for (unsigned int n = 0; n < N; ++n) {
-      _Float16 c = 0;
-      _Float16 c_old = C[m * ldc + n];
+      _FP16 c = 0;
+      _FP16 c_old = C[m * ldc + n];
       for (unsigned int k = 0; k < K; ++k) {
-        _Float16 a, b;
+        _FP16 a, b;
         a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
         b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
         c += a * b;
       }
-      C[m * ldc + n] = static_cast<_Float16>(alpha) * c;
+      C[m * ldc + n] = static_cast<_FP16>(alpha) * c;
       if (beta != 0.0)
-        C[m * ldc + n] += static_cast<_Float16>(beta) * c_old;
+        C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;
     }
   }
 }
 
-static unsigned int isamax_FP16(const unsigned int N, const _Float16 *X,
+static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
                                 const int incX) {
 
   unsigned int max_idx = 0;
-  _Float16 max_val = X[0];
+  _FP16 max_val = X[0];
   for (unsigned int n = 1; n < N; n += incX) {
-    _Float16 cur_val = (X[n] >= 0) ? X[n] : -1 * X[n];
+    _FP16 cur_val = (X[n] >= 0) ? X[n] : -1 * X[n];
     if (cur_val > max_val) {
       max_val = cur_val;
       max_idx = n;
@@ -147,43 +147,43 @@ static unsigned int isamax_FP16(const unsigned int N, const _Float16 *X,
   return max_idx;
 }
 
-void saxpy(const unsigned int N, const float alpha, const _Float16 *X,
-           const int incX, _Float16 *Y, const int incY) {
+void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
+           const int incX, _FP16 *Y, const int incY) {
   saxpy_FP16(N, alpha, X, incX, Y, incY);
 }
 
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
-           const float alpha, const _Float16 *A, const unsigned int lda,
-           const _Float16 *B, const unsigned int ldb, const float beta, _Float16 *C,
+           const float alpha, const _FP16 *A, const unsigned int lda,
+           const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
            const unsigned int ldc) {
   sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
              ldc);
 }
 
-void scopy(const unsigned int N, const _Float16 *X, const int incX, _Float16 *Y,
+void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
            const int incY) {
   scopy_FP16(N, X, incX, Y, incY);
 
 } // namespace nntrainer
 
-_Float16 snrm2(const int N, const _Float16 *X, const int incX) {
+_FP16 snrm2(const int N, const _FP16 *X, const int incX) {
   return snrm2_FP16(N, X, incX);
 }
 
-_Float16 sdot(const unsigned int N, const _Float16 *X, const unsigned int incX,
-            const _Float16 *Y, const unsigned int incY) {
+_FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
+            const _FP16 *Y, const unsigned int incY) {
   return sdot_FP16(N, X, incX, Y, incY);
 }
 
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
-           const unsigned int N, const float alpha, const _Float16 *A,
-           const unsigned int lda, const _Float16 *X, const int incX,
-           const float beta, _Float16 *Y, const int incY) {
+           const unsigned int N, const float alpha, const _FP16 *A,
+           const unsigned int lda, const _FP16 *X, const int incX,
+           const float beta, _FP16 *Y, const int incY) {
   sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
 }
 
-unsigned int isamax(const unsigned int N, const _Float16 *X, const int incX) {
+unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
   /// @todo isamax_FP16 for BLAS_NUM_THREADS
   return isamax_FP16(N, X, incX);
 }
@@ -310,7 +310,7 @@ void sscal(const unsigned int N, const float alpha, void *X, const int incX,
     sscal_raw(N, alpha, (float *)X, incX);
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    sscal(N, alpha, (_Float16 *)X, incX);
+    sscal(N, alpha, (_FP16 *)X, incX);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -344,8 +344,8 @@ void saxpy(const unsigned int N, const float alpha, const void *X,
               static_cast<float *>(Y), incY);
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    saxpy_FP16(N, alpha, static_cast<const _Float16 *>(X), incX,
-               static_cast<_Float16 *>(Y), incY);
+    saxpy_FP16(N, alpha, static_cast<const _FP16 *>(X), incX,
+               static_cast<_FP16 *>(Y), incY);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -411,9 +411,9 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     sgemm_FP16(order, TransA, TransB, M, N, K, alpha,
-               static_cast<const _Float16 *>(A), lda,
-               static_cast<const _Float16 *>(B), ldb, beta,
-               static_cast<_Float16 *>(C), ldc);
+               static_cast<const _FP16 *>(A), lda,
+               static_cast<const _FP16 *>(B), ldb, beta,
+               static_cast<_FP16 *>(C), ldc);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -479,7 +479,7 @@ void scopy(const unsigned int N, const void *X, const int incX, void *Y,
     scopy_raw(N, (float *)X, incX, (float *)Y, incY);
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    scopy_FP16(N, (_Float16 *)X, incX, (_Float16 *)Y, incY);
+    scopy_FP16(N, (_FP16 *)X, incX, (_FP16 *)Y, incY);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -542,9 +542,9 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     return sgemv_FP16(order, TransA, M, N, alpha,
-                      static_cast<const _Float16 *>(A), lda,
-                      static_cast<const _Float16 *>(X), incX, beta,
-                      static_cast<_Float16 *>(Y), incY);
+                      static_cast<const _FP16 *>(A), lda,
+                      static_cast<const _FP16 *>(X), incX, beta,
+                      static_cast<_FP16 *>(Y), incY);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
index 44102aad4f04b5bdd8f6efe3fb8b241f04710720..0193e4d693da64b6d3df7ebdf9dfe00548105122 100644 (file)
@@ -40,24 +40,24 @@ enum CBLAS_TRANSPOSE {
 namespace nntrainer {
 
 #ifdef ENABLE_FP16
-void sscal(const unsigned int N, const float alpha, _Float16 *X, const int incX);
-_Float16 snrm2(const int N, const _Float16 *X, const int incX);
-void scopy(const unsigned int N, const _Float16 *X, const int incX, _Float16 *Y,
+void sscal(const unsigned int N, const float alpha, _FP16 *X, const int incX);
+_FP16 snrm2(const int N, const _FP16 *X, const int incX);
+void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
            const int intY);
-_Float16 sdot(const unsigned int N, const _Float16 *X, const unsigned int incX,
-            const _Float16 *Y, const unsigned int incY);
-void saxpy(const unsigned int N, const float alpha, const _Float16 *X,
-           const int incX, _Float16 *Y, const int incY);
+_FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
+            const _FP16 *Y, const unsigned int incY);
+void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
+           const int incX, _FP16 *Y, const int incY);
 void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
-           const float alpha, const _Float16 *A, const unsigned int lda,
-           const _Float16 *B, const unsigned int ldb, const float beta, _Float16 *C,
+           const float alpha, const _FP16 *A, const unsigned int lda,
+           const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
            const unsigned int ldc);
 void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
-           const unsigned int N, const float alpha, const _Float16 *A,
-           const unsigned int lda, const _Float16 *X, const int incX,
-           const float beta, _Float16 *Y, const int incY);
-unsigned int isamax(const unsigned int N, const _Float16 *X, const int incX);
+           const unsigned int N, const float alpha, const _FP16 *A,
+           const unsigned int lda, const _FP16 *X, const int incX,
+           const float beta, _FP16 *Y, const int incY);
+unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX);
 #endif
 
 void sscal(const unsigned int N, const float alpha, void *X, const int incX,
index c2eddc336951664489bdfdf4283b5116ed16dd74..221b21ac9fea01f875bcb9817869506520aa8348 100644 (file)
@@ -174,9 +174,9 @@ void Tensor::allocate() {
 
     } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-      mem_data = new MemoryData((void *)(new _Float16[dim.getDataLen()]{}));
+      mem_data = new MemoryData((void *)(new _FP16[dim.getDataLen()]{}));
       data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
-        delete[] mem_data->template getAddr<_Float16>();
+        delete[] mem_data->template getAddr<_FP16>();
         delete mem_data;
       });
 #else
@@ -216,8 +216,8 @@ bool Tensor::operator==(const Tensor &rhs) const {
     }
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *_data = getData<_Float16>();
-    const _Float16 *_rdata = rhs.getData<_Float16>();
+    const _FP16 *_data = getData<_FP16>();
+    const _FP16 *_rdata = rhs.getData<_FP16>();
     for (size_t i = 0; i < len; ++i) {
       // @todo: need to check if float casting valid
       if ((std::isnan((float)_data[i]) && !std::isnan((float)_rdata[i])) ||
@@ -238,8 +238,8 @@ void Tensor::setRandNormal(float mean, float std) {
     setDist<float, std::normal_distribution<float>>(
       std::normal_distribution<float>(mean, std));
   } else if (this->getDataType() == ml::train::TensorDim::DataType::FP16) {
-    throw std::invalid_argument(
-      "_Float16 is not supported by std::normal_distribution");
+    setDist<_FP16, std::normal_distribution<float>>(
+      std::normal_distribution<float>(mean, std));
   }
 }
 
@@ -248,8 +248,8 @@ void Tensor::setRandUniform(float min, float max) {
     setDist<float, std::uniform_real_distribution<float>>(
       std::uniform_real_distribution<float>(min, max));
   } else if (this->getDataType() == ml::train::TensorDim::DataType::FP16) {
-    throw std::invalid_argument(
-      "_Float16 is not supported by std::uniform_real_distribution");
+    setDist<_FP16, std::uniform_real_distribution<float>>(
+      std::uniform_real_distribution<float>(min, max));
   }
 }
 
@@ -259,8 +259,8 @@ void Tensor::setRandBernoulli(float probability) {
       std::bernoulli_distribution(probability));
   } else if (this->getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    setDist<_Float16, std::bernoulli_distribution>(
-      std::bernoulli_distribution((_Float16)probability));
+    setDist<_FP16, std::bernoulli_distribution>(
+      std::bernoulli_distribution((_FP16)probability));
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -359,11 +359,11 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
       << output.getName() << " is not allocated";
   } else if (getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-    NNTR_THROW_IF(getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(getData<_FP16>() == nullptr, std::invalid_argument)
       << getName() << " is not allocated";
-    NNTR_THROW_IF(m.getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(m.getData<_FP16>() == nullptr, std::invalid_argument)
       << m.getName() << " is not allocated";
-    NNTR_THROW_IF(output.getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(output.getData<_FP16>() == nullptr, std::invalid_argument)
       << output.getName() << " is not allocated";
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
@@ -411,8 +411,8 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
             for (unsigned int h = 0; h < height(); ++h) {
               for (unsigned int w = 0; w < width(); ++w) {
                 output.addValue(b, c, h, w,
-                                getValue<_Float16>(b, c, h, w) *
-                                  m.getValue<_Float16>(b, c, h, w),
+                                getValue<_FP16>(b, c, h, w) *
+                                  m.getValue<_FP16>(b, c, h, w),
                                 beta);
               }
             }
@@ -422,11 +422,11 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
         for (unsigned int b = 0; b < batch(); ++b) {
           for (unsigned int c = 0; c < channel(); ++c) {
             for (unsigned int h = 0; h < height(); ++h) {
-              _Float16 *out_data = output.getAddress<_Float16>(b, c, h, 0);
-              const _Float16 *m_data = m.getAddress<_Float16>(b, c, h, 0);
-              const _Float16 *in_data = getAddress<_Float16>(b, c, h, 0);
+              _FP16 *out_data = output.getAddress<_FP16>(b, c, h, 0);
+              const _FP16 *m_data = m.getAddress<_FP16>(b, c, h, 0);
+              const _FP16 *in_data = getAddress<_FP16>(b, c, h, 0);
               std::transform(in_data, in_data + width(), m_data, out_data,
-                             std::multiplies<_Float16>());
+                             std::multiplies<_FP16>());
             }
           }
         }
@@ -475,8 +475,8 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
             for (unsigned int w = 0; w < width(); ++w) {
               for (unsigned int c = 0; c < channel(); ++c) {
                 output.addValue(b, c, h, w,
-                                getValue<_Float16>(b, c, h, w) *
-                                  m.getValue<_Float16>(b, c, h, w),
+                                getValue<_FP16>(b, c, h, w) *
+                                  m.getValue<_FP16>(b, c, h, w),
                                 beta);
               }
             }
@@ -488,11 +488,11 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
         for (unsigned int b = 0; b < batch(); ++b) {
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
-              _Float16 *out_data = output.getAddress<_Float16>(b, 0, h, w);
-              const _Float16 *m_data = m.getAddress<_Float16>(b, 0, h, w);
-              const _Float16 *in_data = getAddress<_Float16>(b, 0, h, w);
+              _FP16 *out_data = output.getAddress<_FP16>(b, 0, h, w);
+              const _FP16 *m_data = m.getAddress<_FP16>(b, 0, h, w);
+              const _FP16 *in_data = getAddress<_FP16>(b, 0, h, w);
               std::transform(in_data, in_data + channel(), m_data, out_data,
-                             std::multiplies<_Float16>());
+                             std::multiplies<_FP16>());
             }
           }
         }
@@ -540,11 +540,11 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
       << output.getName() << " is not allocated";
   } else if (getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-    NNTR_THROW_IF(getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(getData<_FP16>() == nullptr, std::invalid_argument)
       << getName() << " is not allocated";
-    NNTR_THROW_IF(m.getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(m.getData<_FP16>() == nullptr, std::invalid_argument)
       << m.getName() << " is not allocated";
-    NNTR_THROW_IF(output.getData<_Float16>() == nullptr, std::invalid_argument)
+    NNTR_THROW_IF(output.getData<_FP16>() == nullptr, std::invalid_argument)
       << output.getName() << " is not allocated";
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
@@ -590,8 +590,8 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
             for (unsigned int h = 0; h < height(); ++h) {
               for (unsigned int w = 0; w < width(); ++w) {
                 output.setValue(b, c, h, w,
-                                getValue<_Float16>(b, c, h, w) +
-                                  m.getValue<_Float16>(b, c, h, w) * beta);
+                                getValue<_FP16>(b, c, h, w) +
+                                  m.getValue<_FP16>(b, c, h, w) * beta);
               }
             }
           }
@@ -600,11 +600,11 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
         for (unsigned int b = 0; b < batch(); ++b) {
           for (unsigned int c = 0; c < channel(); ++c) {
             for (unsigned int h = 0; h < height(); ++h) {
-              _Float16 *out_data = output.getAddress<_Float16>(b, c, h, 0);
-              const _Float16 *m_data = m.getAddress<_Float16>(b, c, h, 0);
-              const _Float16 *in_data = getAddress<_Float16>(b, c, h, 0);
+              _FP16 *out_data = output.getAddress<_FP16>(b, c, h, 0);
+              const _FP16 *m_data = m.getAddress<_FP16>(b, c, h, 0);
+              const _FP16 *in_data = getAddress<_FP16>(b, c, h, 0);
               std::transform(in_data, in_data + width(), m_data, out_data,
-                             std::plus<_Float16>());
+                             std::plus<_FP16>());
             }
           }
         }
@@ -652,8 +652,8 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
             for (unsigned int w = 0; w < width(); ++w) {
               for (unsigned int c = 0; c < channel(); ++c) {
                 output.setValue(b, c, h, w,
-                                getValue<_Float16>(b, c, h, w) +
-                                  m.getValue<_Float16>(b, c, h, w) * beta);
+                                getValue<_FP16>(b, c, h, w) +
+                                  m.getValue<_FP16>(b, c, h, w) * beta);
               }
             }
           }
@@ -664,11 +664,11 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
         for (unsigned int b = 0; b < batch(); ++b) {
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
-              _Float16 *out_data = output.getAddress<_Float16>(b, 0, h, w);
-              const _Float16 *m_data = m.getAddress<_Float16>(b, 0, h, w);
-              const _Float16 *in_data = getAddress<_Float16>(b, 0, h, w);
+              _FP16 *out_data = output.getAddress<_FP16>(b, 0, h, w);
+              const _FP16 *m_data = m.getAddress<_FP16>(b, 0, h, w);
+              const _FP16 *in_data = getAddress<_FP16>(b, 0, h, w);
               std::transform(in_data, in_data + channel(), m_data, out_data,
-                             std::plus<_Float16>());
+                             std::plus<_FP16>());
             }
           }
         }
@@ -694,7 +694,7 @@ int Tensor::multiply_i(float const &value) {
     sscal(len, value, data, 1);
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    _Float16 *data = getData<_Float16>();
+    _FP16 *data = getData<_FP16>();
     unsigned int len = size();
     sscal(len, value, data, 1);
 #else
@@ -711,18 +711,18 @@ Tensor Tensor::multiply(float const &value) const {
 
 Tensor &Tensor::multiply(float const &value, Tensor &out) const {
   /// @todo add unittest
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+  // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     auto f = std::bind(std::multiplies<float>(), std::placeholders::_1, value);
     return apply(f, out);
-  } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    auto f = std::bind(std::multiplies<_Float16>(), std::placeholders::_1,
-                       static_cast<_Float16>(value));
-    return apply(f, out);
-#else
-    throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-  }
+//   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+// #ifdef ENABLE_FP16
+//     auto f = std::bind(std::multiplies<_FP16>(), std::placeholders::_1,
+//                        static_cast<_FP16>(value));
+//     return apply(f, out);
+// #else
+//     throw std::invalid_argument("Error: enable-fp16 is not enabled");
+// #endif
+  // }
   return out;
 }
 
@@ -783,15 +783,15 @@ Tensor &Tensor::multiply(Tensor const &m, Tensor &output,
 
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    auto f = [&](const BroadcastInfo &e, const _Float16 *buf,
-                 const _Float16 *m_buf, _Float16 *out_buf) {
+    auto f = [&](const BroadcastInfo &e, const _FP16 *buf,
+                 const _FP16 *m_buf, _FP16 *out_buf) {
       if (e.strides[3] == 1 && output.strides[3] == 1 && strides[3] == 1 &&
           beta == 0.0) {
         std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
-                       std::multiplies<_Float16>());
+                       std::multiplies<_FP16>());
       } else {
         for (unsigned int i = 0; i < e.buffer_size; ++i) {
-          *out_buf = *buf * *m_buf + static_cast<_Float16>(beta) * *out_buf;
+          *out_buf = *buf * *m_buf + static_cast<_FP16>(beta) * *out_buf;
           buf += strides[3];
           m_buf += e.strides[3];
           out_buf += output.strides[3];
@@ -831,24 +831,24 @@ Tensor Tensor::divide(float const &value) const {
 }
 
 Tensor &Tensor::divide(float const &value, Tensor &out) const {
-  /// @todo add unittest, _Float16 ZeroDivisionError
+  /// @todo add unittest, _FP16 ZeroDivisionError
   if (value == 0.0f) {
     std::stringstream ss;
     ss << "[Tensor] divide by value failed, value: " << value;
     throw std::invalid_argument(ss.str().c_str());
   }
 
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+  // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     auto f = std::bind(std::divides<float>(), std::placeholders::_1, value);
     return apply(f, out);
-  } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    auto f = std::bind(std::divides<_Float16>(), std::placeholders::_1, static_cast<_Float16>(value));
-    return apply(f, out);
-#else
-    throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-  }
+//   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+// #ifdef ENABLE_FP16
+//     auto f = std::bind(std::divides<_FP16>(), std::placeholders::_1, static_cast<_FP16>(value));
+//     return apply(f, out);
+// #else
+//     throw std::invalid_argument("Error: enable-fp16 is not enabled");
+// #endif
+//   }
   return out;
 }
 
@@ -892,11 +892,11 @@ Tensor &Tensor::divide(Tensor const &m, Tensor &output) const {
     apply_broadcast(m, f, output);
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    auto f = [&](const BroadcastInfo &e, const _Float16 *buf,
-                 const _Float16 *m_buf, _Float16 *out_buf) {
+    auto f = [&](const BroadcastInfo &e, const _FP16 *buf,
+                 const _FP16 *m_buf, _FP16 *out_buf) {
       if (e.strides[3] == 1 && output.strides[3] == 1 && strides[3] == 1) {
         std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
-                       std::divides<_Float16>());
+                       std::divides<_FP16>());
       } else {
         for (unsigned int i = 0; i < e.buffer_size; ++i) {
           *out_buf = *buf / *m_buf;
@@ -931,18 +931,18 @@ Tensor Tensor::add(float const &value) const {
 
 Tensor &Tensor::add(float const &value, Tensor &out) const {
   /// @todo add unittest
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+  // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     auto f = std::bind(std::plus<float>(), std::placeholders::_1, value);
     return apply(f, out);
-  } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    auto f = std::bind(std::plus<_Float16>(), std::placeholders::_1,
-                       static_cast<_Float16>(value));
-    return apply(f, out);
-#else
-    throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-  }
+//   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+// #ifdef ENABLE_FP16
+//     auto f = std::bind(std::plus<_FP16>(), std::placeholders::_1,
+//                        static_cast<_FP16>(value));
+//     return apply(f, out);
+// #else
+//     throw std::invalid_argument("Error: enable-fp16 is not enabled");
+// #endif
+//   }
   return out;
 }
 
@@ -968,10 +968,10 @@ int Tensor::add_i(Tensor const &m, float const alpha) {
 
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    auto f = [&](const BroadcastInfo &e, const _Float16 *buf,
-                 const _Float16 *m_buf, _Float16 *out_buf) {
+    auto f = [&](const BroadcastInfo &e, const _FP16 *buf,
+                 const _FP16 *m_buf, _FP16 *out_buf) {
       saxpy(e.buffer_size, alpha, m_buf, e.strides[3], out_buf, strides[3]);
-      /// @todo: saxpy is not valid for _Float16
+      /// @todo: saxpy is not valid for _FP16
     };
 
     /// @todo: enable this after add_strided supports broadcast
@@ -1022,15 +1022,15 @@ Tensor &Tensor::add(Tensor const &m, Tensor &output, float const alpha) const {
     apply_broadcast(m, f, output);
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    auto f = [&](const BroadcastInfo &e, const _Float16 *buf,
-                 const _Float16 *m_buf, _Float16 *out_buf) {
+    auto f = [&](const BroadcastInfo &e, const _FP16 *buf,
+                 const _FP16 *m_buf, _FP16 *out_buf) {
       if (e.strides[3] == 1 && strides[3] == 1 && strides[3] == 1 &&
           alpha == 0) {
         std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
-                       std::plus<_Float16>());
+                       std::plus<_FP16>());
       } else {
         for (unsigned int i = 0; i < e.buffer_size; ++i) {
-          *out_buf = *buf + *m_buf * static_cast<_Float16>(alpha);
+          *out_buf = *buf + *m_buf * static_cast<_FP16>(alpha);
           buf += strides[3];
           m_buf += e.strides[3];
           out_buf += strides[3];
@@ -1057,18 +1057,18 @@ Tensor Tensor::subtract(float const &value) const {
 
 Tensor &Tensor::subtract(float const &value, Tensor &out) const {
   /// @todo add unittest
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+  // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     auto f = std::bind(std::minus<float>(), std::placeholders::_1, value);
     return apply(f, out);
-  } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    auto f = std::bind(std::minus<_Float16>(), std::placeholders::_1,
-                       static_cast<_Float16>(value));
-    return apply(f, out);
-#else
-    ml_loge("%s", "Error: enable-fp16 is not enabled");
-#endif
-  }
+//   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+// #ifdef ENABLE_FP16
+//     auto f = std::bind(std::minus<_FP16>(), std::placeholders::_1,
+//                        static_cast<_FP16>(value));
+//     return apply(f, out);
+// #else
+//     ml_loge("%s", "Error: enable-fp16 is not enabled");
+// #endif
+//   }
   return out; // shouldn't reach
 }
 
@@ -1091,21 +1091,21 @@ Tensor Tensor::pow(float exponent) const {
 }
 
 Tensor &Tensor::pow(float exponent, Tensor &out) const {
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+  // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     auto f = [exponent](float in) { return powf(in, exponent); };
     return apply(f, out);
-  }
-  if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    auto f = [exponent](_Float16 in) {
-      return static_cast<_Float16>(powf(in, exponent));
-    };
-    return apply(f, out);
-#else
-    ml_loge("%s", "Error: enable-fp16 is not enabled");
-#endif
-  }
-  return out;
+  // }
+//   if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+// #ifdef ENABLE_FP16
+//     auto f = [exponent](_FP16 in) {
+//       return static_cast<_FP16>(powf(in, exponent));
+//     };
+//     return apply(f, out);
+// #else
+//     ml_loge("%s", "Error: enable-fp16 is not enabled");
+// #endif
+//   }
+  // return out;
 }
 
 Tensor Tensor::getBatchSlice(size_t offset, unsigned int size) const {
@@ -1311,10 +1311,10 @@ std::vector<Tensor> Tensor::split(std::vector<size_t> sizes, int axis) {
     auto iter_value =
       [this, is_format_nchw](
         std::array<size_t, 4> &loc, const std::array<size_t, 4> &end_loc,
-        const std::array<size_t, 4> &reset_dim_arr) -> _Float16 & {
+        const std::array<size_t, 4> &reset_dim_arr) -> _FP16 & {
       auto &value = (is_format_nchw)
-                      ? getValue<_Float16>(loc[0], loc[1], loc[2], loc[3])
-                      : getValue<_Float16>(loc[0], loc[3], loc[1], loc[2]);
+                      ? getValue<_FP16>(loc[0], loc[1], loc[2], loc[3])
+                      : getValue<_FP16>(loc[0], loc[3], loc[1], loc[2]);
       for (int i = 3; i >= 0; --i) {
         loc[i]++;
         if (loc[i] == end_loc[i]) {
@@ -1380,7 +1380,7 @@ std::vector<Tensor> Tensor::split(std::vector<size_t> sizes, int axis) {
                          ret_dims[i].width(), ret_dims[i].channel()};
       }
 
-      ret_t.apply_i([&iter_value, &loc, &end_loc, &reset_dim_arr](_Float16 _) {
+      ret_t.apply_i([&iter_value, &loc, &end_loc, &reset_dim_arr](float _) {
         return iter_value(loc, end_loc, reset_dim_arr);
       });
     }
@@ -1487,10 +1487,10 @@ Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
     auto iter_value =
       [is_format_nchw](
         std::array<unsigned, 4> &loc, const std::array<unsigned, 4> &start_loc,
-        Tensor &t, const std::array<unsigned, 4> &ref_dim_arr) -> _Float16 & {
+        Tensor &t, const std::array<unsigned, 4> &ref_dim_arr) -> _FP16 & {
       auto &value = is_format_nchw
-                      ? t.getValue<_Float16>(loc[0], loc[1], loc[2], loc[3])
-                      : t.getValue<_Float16>(loc[0], loc[3], loc[1], loc[2]);
+                      ? t.getValue<_FP16>(loc[0], loc[1], loc[2], loc[3])
+                      : t.getValue<_FP16>(loc[0], loc[3], loc[1], loc[2]);
 
       for (int i = 3; i >= 0; --i) {
         loc[i]++;
@@ -1526,7 +1526,7 @@ Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
 
       for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
         iter_value(loc, start_loc, ret, tensor_dim_arr) =
-          t.getValue<_Float16>(i);
+          t.getValue<_FP16>(i);
       }
 
       if (is_format_nchw) {
@@ -1598,17 +1598,17 @@ void Tensor::apply_broadcast(
 #ifdef ENABLE_FP16
 void Tensor::apply_broadcast(
   Tensor const &m,
-  std::function<void(const BroadcastInfo &e, const _Float16 *, const _Float16 *,
-                     _Float16 *)>
+  std::function<void(const BroadcastInfo &e, const _FP16 *, const _FP16 *,
+                     _FP16 *)>
     v_func,
   Tensor &output) const {
   CREATE_IF_EMPTY_DIMS(output, dim, nullptr);
 
-  NNTR_THROW_IF(getData<_Float16>() == nullptr, std::invalid_argument)
+  NNTR_THROW_IF(getData<_FP16>() == nullptr, std::invalid_argument)
     << getName() << " is not allocated";
-  NNTR_THROW_IF(m.getData<_Float16>() == nullptr, std::invalid_argument)
+  NNTR_THROW_IF(m.getData<_FP16>() == nullptr, std::invalid_argument)
     << m.getName() << " is not allocated";
-  NNTR_THROW_IF(output.getData<_Float16>() == nullptr, std::invalid_argument)
+  NNTR_THROW_IF(output.getData<_FP16>() == nullptr, std::invalid_argument)
     << output.getName() << " is not allocated";
 
   /// shortcut to cover when dimension matches
@@ -1618,8 +1618,8 @@ void Tensor::apply_broadcast(
     BroadcastInfo e;
     e.buffer_size = size();
     e.strides[3] = 1;
-    v_func(e, getData<_Float16>(), m.getData<_Float16>(),
-           output.getData<_Float16>());
+    v_func(e, getData<_FP16>(), m.getData<_FP16>(),
+           output.getData<_FP16>());
     return;
   }
 
@@ -1628,15 +1628,15 @@ void Tensor::apply_broadcast(
 
 void Tensor::apply_broadcast_util(
   Tensor const &m,
-  std::function<void(const BroadcastInfo &e, const _Float16 *, const _Float16 *,
-                     _Float16 *)>
+  std::function<void(const BroadcastInfo &e, const _FP16 *, const _FP16 *,
+                     _FP16 *)>
     v_func,
   Tensor &output, const BroadcastInfo &e, int cur_axis, size_t offset,
   size_t m_offset) const {
 
-  const _Float16 *buf = this->getData<_Float16>();
-  const _Float16 *m_buf = m.getData<_Float16>();
-  _Float16 *out_buf = output.getData<_Float16>();
+  const _FP16 *buf = this->getData<_FP16>();
+  const _FP16 *m_buf = m.getData<_FP16>();
+  _FP16 *out_buf = output.getData<_FP16>();
 
   if (e.buffer_axis == cur_axis) {
     v_func(e, buf + offset, m_buf + m_offset, out_buf + offset);
@@ -1708,13 +1708,13 @@ Tensor Tensor::sum_by_batch() const {
           ones.getData<float>(), 1, 0.0, rdata, 1);
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
-    _Float16 *rdata = ret.getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
+    _FP16 *rdata = ret.getData<_FP16>();
 
     Tensor ones(1, 1, 1, feat_len, this->getTensorType());
-    ones.setValue((_Float16)1.0);
+    ones.setValue((_FP16)1.0);
     sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData<_Float16>(), 1, 0.0, rdata, 1);
+          ones.getData<_FP16>(), 1, 0.0, rdata, 1);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -1843,7 +1843,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
 
     NNTR_THROW_IF(!contiguous, std::invalid_argument)
       << getName() << " is not contiguous, cannot sum";
@@ -1853,7 +1853,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
 
     if (dim.getDim()[axis] == 1 and alpha == 1.0 and !beta) {
       CREATE_IF_EMPTY_DIMS(ret, dim);
-      ret.copy(this->getData<_Float16>());
+      ret.copy(this->getData<_FP16>());
       return ret;
     }
 
@@ -1866,7 +1866,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
       Tensor ones(1, 1, 1, batch, this->getTensorType());
       ones.setValue(alpha);
       sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-            ones.getData<_Float16>(), 1, beta, ret.getData<_Float16>(), 1);
+            ones.getData<_FP16>(), 1, beta, ret.getData<_FP16>(), 1);
     } break;
     case 1: {
       CREATE_IF_EMPTY_DIMS(ret, dim[0], 1, dim[2], dim[3], getTensorType());
@@ -1876,17 +1876,17 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
         Tensor ones(1, 1, 1, n, this->getTensorType());
         ones.setValue(alpha);
         sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
-              ones.getData<_Float16>(), 1, beta, ret.getData<_Float16>(), 1);
+              ones.getData<_FP16>(), 1, beta, ret.getData<_FP16>(), 1);
       } else {
         unsigned int feat_len = dim[2] * dim[3];
         unsigned int t_axis = dim[1];
         Tensor ones(1, 1, 1, t_axis, getTensorType());
         ones.setValue(alpha);
-        _Float16 *rdata = ret.getData<_Float16>();
+        _FP16 *rdata = ret.getData<_FP16>();
         for (unsigned int k = 0; k < dim[0]; ++k) {
           sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
                 &data[k * dim.getFeatureLen()], feat_len,
-                ones.getData<_Float16>(), 1, beta, &rdata[k * feat_len], 1);
+                ones.getData<_FP16>(), 1, beta, &rdata[k * feat_len], 1);
         }
       }
     } break;
@@ -1898,24 +1898,24 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
         unsigned int t_axis = dim[2];
         Tensor ones(1, 1, 1, t_axis, getTensorType());
         ones.setValue(alpha);
-        _Float16 *rdata = ret.getData<_Float16>();
+        _FP16 *rdata = ret.getData<_FP16>();
         for (unsigned int k = 0; k < dim[0]; ++k) {
           sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
                 &data[k * dim.getFeatureLen()], feat_len,
-                ones.getData<_Float16>(), 1, beta, &rdata[k * feat_len], 1);
+                ones.getData<_FP16>(), 1, beta, &rdata[k * feat_len], 1);
         }
       } else {
         unsigned int t_3 = dim[3];
         unsigned int t_axis = dim[2];
         Tensor ones(1, 1, 1, t_axis, getTensorType());
         ones.setValue(alpha);
-        _Float16 *rdata = ret.getData<_Float16>();
+        _FP16 *rdata = ret.getData<_FP16>();
         for (unsigned int k = 0; k < dim[0]; ++k) {
           for (unsigned int c = 0; c < dim[1]; ++c) {
             unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
             unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim[3];
             sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                  ones.getData<_Float16>(), 1, beta, &rdata[ridx], 1);
+                  ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
           }
         }
       }
@@ -1927,13 +1927,13 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
         unsigned int t_axis = dim[3];
         Tensor ones(1, 1, 1, t_axis, getTensorType());
         ones.setValue(alpha);
-        _Float16 *rdata = ret.getData<_Float16>();
+        _FP16 *rdata = ret.getData<_FP16>();
         for (unsigned int k = 0; k < dim[0]; ++k) {
           for (unsigned int c = 0; c < dim[2]; ++c) {
             unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
             unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim[1];
             sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                  ones.getData<_Float16>(), 1, beta, &rdata[ridx], 1);
+                  ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
           }
         }
       } else {
@@ -1942,7 +1942,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
         Tensor ones(1, 1, 1, n, getTensorType());
         ones.setValue(alpha);
         sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
-              ones.getData<_Float16>(), 1, beta, ret.getData<_Float16>(), 1);
+              ones.getData<_FP16>(), 1, beta, ret.getData<_FP16>(), 1);
       }
     } break;
     default:
@@ -2234,9 +2234,9 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
-    const _Float16 *mdata = m.getData<_Float16>();
-    _Float16 *rdata = result.getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
+    const _FP16 *mdata = m.getData<_FP16>();
+    _FP16 *rdata = result.getData<_FP16>();
     const float alpha = 1.0f;
     enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
     enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
@@ -2250,7 +2250,7 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
     /// case1: (1 * K) X (K * 1)
     if (M == 1 && N == 1) {
       *rdata =
-        sdot(K, data, 1, mdata, 1) + static_cast<_Float16>(beta) * (*rdata);
+        sdot(K, data, 1, mdata, 1) + static_cast<_FP16>(beta) * (*rdata);
     }
     /// case2: (M * K) X (K * 1)
     else if (N == 1) {
@@ -2350,8 +2350,8 @@ Tensor &Tensor::transpose(const std::string &direction, Tensor &out) const {
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *inptr = getData<_Float16>();
-    _Float16 *outptr = out.getData<_Float16>();
+    const _FP16 *inptr = getData<_FP16>();
+    _FP16 *outptr = out.getData<_FP16>();
     switch (indexI) {
     case 0:
       if (indexJ == 1) {
@@ -2432,8 +2432,8 @@ void Tensor::dropout_mask(float dropout) {
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    _Float16 scale = static_cast<_Float16>(1.0 / (1 - dropout));
-    _Float16 *data_ = getData<_Float16>();
+    _FP16 scale = static_cast<_FP16>(1.0 / (1 - dropout));
+    _FP16 *data_ = getData<_FP16>();
     for (unsigned int i = 0; i < size(); ++i) {
       if (data_[i] >= dropout)
         data_[i] = scale;
@@ -2467,9 +2467,9 @@ void Tensor::filter_mask(const Tensor &mask_len, bool reverse) {
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     for (unsigned int b = 0; b < batch(); b++) {
-      _Float16 *addr = getAddress<_Float16>(b, 0, 0, 0);
+      _FP16 *addr = getAddress<_FP16>(b, 0, 0, 0);
       const uint *mask_len_val = mask_len.getAddress<uint>(b, 0, 0, 0);
-      std::fill(addr, addr + (*mask_len_val), (_Float16)en_mask_val);
+      std::fill(addr, addr + (*mask_len_val), (_FP16)en_mask_val);
     }
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
@@ -2504,17 +2504,17 @@ void Tensor::zoneout_mask(Tensor &opposite, float zoneout) {
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    _Float16 zoneout_fp16 = (_Float16)zoneout;
+    _FP16 zoneout_fp16 = (_FP16)zoneout;
     opposite.setRandBernoulli(zoneout_fp16);
 
-    _Float16 *data = getData<_Float16>();
-    _Float16 *opposite_data = opposite.getData<_Float16>();
+    _FP16 *data = getData<_FP16>();
+    _FP16 *opposite_data = opposite.getData<_FP16>();
 
     for (unsigned int i = 0; i < size(); ++i) {
       if (opposite_data[i] > epsilon) {
-        data[i] = (_Float16)0.0;
+        data[i] = (_FP16)0.0;
       } else {
-        data[i] = (_Float16)1.0;
+        data[i] = (_FP16)1.0;
       }
     }
 #else
@@ -2630,7 +2630,7 @@ void Tensor::print(std::ostream &out) const {
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
     unsigned int len = size();
     out << "data addr: " << data << '\n';
     out << dim;
@@ -2651,7 +2651,7 @@ void Tensor::print(std::ostream &out) const {
           for (unsigned int i = 0; i < height(); i++) {
             for (unsigned int j = 0; j < width(); j++) {
               out << std::setw(10) << std::setprecision(10)
-                  << (float)this->getValue<_Float16>(k, l, i, j) << " ";
+                  << (float)this->getValue<_FP16>(k, l, i, j) << " ";
             }
             out << std::endl;
           }
@@ -2665,7 +2665,7 @@ void Tensor::print(std::ostream &out) const {
           for (unsigned int j = 0; j < width(); j++) {
             for (unsigned int l = 0; l < channel(); l++) {
               out << std::setw(10) << std::setprecision(10)
-                  << (float)this->getValue<_Float16>(k, l, i, j) << " ";
+                  << (float)this->getValue<_FP16>(k, l, i, j) << " ";
             }
             out << std::endl;
           }
@@ -2778,7 +2778,7 @@ void Tensor::copy(const void *buf) {
 
   if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    if (buf == getData<_Float16>()) {
+    if (buf == getData<_FP16>()) {
       return;
     }
 #else
@@ -2795,7 +2795,7 @@ void Tensor::copy(const void *buf) {
 
   if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    scopy(size(), (_Float16 *)buf, 1, getData<_Float16>(), 1);
+    scopy(size(), (_FP16 *)buf, 1, getData<_FP16>(), 1);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -2823,7 +2823,7 @@ void Tensor::copy_with_stride(const Tensor &from) {
         for (unsigned int c = 0; c < channel(); ++c) {
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
-              setValue(b, c, h, w, from.getValue<_Float16>(b, c, h, w));
+              setValue(b, c, h, w, from.getValue<_FP16>(b, c, h, w));
             }
           }
         }
@@ -2850,7 +2850,7 @@ void Tensor::copy_with_stride(const Tensor &from) {
         for (unsigned int c = 0; c < channel(); ++c) {
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
-              setValue(b, c, h, w, from.getValue<_Float16>(b, c, h, w));
+              setValue(b, c, h, w, from.getValue<_FP16>(b, c, h, w));
             }
           }
         }
@@ -2876,7 +2876,7 @@ void Tensor::copy(const Tensor &from) {
       copy(from.getData());
     } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-      copy(from.getData<_Float16>());
+      copy(from.getData<_FP16>());
 #else
       throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -3056,8 +3056,8 @@ void Tensor::setValue(float val) {
     std::fill(data, data + size(), val);
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    _Float16 *data = getData<_Float16>();
-    std::fill(data, data + size(), static_cast<_Float16>(val));
+    _FP16 *data = getData<_FP16>();
+    std::fill(data, data + size(), static_cast<_FP16>(val));
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -3073,9 +3073,9 @@ void Tensor::setZero() {
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     if (contiguous)
-      sscal(size(), 0, getData<_Float16>(), 1);
+      sscal(size(), 0, getData<_FP16>(), 1);
     else
-      apply_i([](_Float16 val) -> _Float16 { return 0; });
+      apply_i([](float val) -> float { return 0; });
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -3102,7 +3102,7 @@ std::vector<unsigned int> Tensor::argmax() const {
   }
   if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
     size_t batch_size = batch();
     size_t feature_len = dim.getFeatureLen();
 
@@ -3131,7 +3131,7 @@ float Tensor::l2norm() const {
     ret = snrm2(len, data, 1);
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
     ret = snrm2(len, data, 1);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
@@ -3154,7 +3154,7 @@ float Tensor::max_abs() const {
 
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
 
     unsigned int idx = isamax(len, data, 1);
     ret = *(data + idx);
@@ -3195,11 +3195,11 @@ void Tensor::normalization_i() {
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    const _Float16 *data = getData<_Float16>();
+    const _FP16 *data = getData<_FP16>();
 
     auto bounds = std::minmax_element(data, data + size());
-    const _Float16 min = *bounds.first;
-    const _Float16 max = *bounds.second;
+    const _FP16 min = *bounds.first;
+    const _FP16 max = *bounds.second;
 
     if (max == min) {
       Tensor tmp = *this;
@@ -3247,11 +3247,11 @@ void Tensor::standardization_i() {
 #ifdef ENABLE_FP16
     Tensor std_dev_by_batch(dim.batch(), 1, 1, 1);
     std_dev_by_batch.setZero();
-    _Float16 *std_dev = std_dev_by_batch.getData<_Float16>();
+    _FP16 *std_dev = std_dev_by_batch.getData<_FP16>();
 
     for (unsigned int k = 0; k < dim.batch(); ++k) {
       Tensor sub_this = this->getBatchSlice(k, 1);
-      std_dev[k] = static_cast<_Float16>(sub_this.l2norm());
+      std_dev[k] = static_cast<_FP16>(sub_this.l2norm());
     }
 
     std_dev_by_batch.divide_i(dim.getFeatureLen());
@@ -3362,7 +3362,7 @@ Tensor Tensor::rotate_180(Tensor in) {
         for (unsigned int k = 0; k < in.height(); ++k) {
           for (unsigned int l = 0; l < in.width(); ++l) {
             output.setValue(i, j, k, l,
-                            in.getValue<_Float16>(i, j, (in.height() - k - 1),
+                            in.getValue<_FP16>(i, j, (in.height() - k - 1),
                                                   (in.width() - l - 1)));
           }
         }
index 1ac798236181ed967b4204363deac2de27223b8a..bef60bfedd0583bb588dc1cf100c65185b618b4e 100644 (file)
@@ -269,7 +269,7 @@ public:
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
 
 #ifdef ENABLE_FP16
-  Tensor(std::vector<std::vector<std::vector<std::vector<_Float16>>>> const &d,
+  Tensor(std::vector<std::vector<std::vector<std::vector<_FP16>>>> const &d,
          ml::train::TensorDim::TensorType t_type) {
 
     if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
@@ -293,9 +293,9 @@ public:
     strides = dim.computeStrides();
 
     MemoryData *mem_data =
-      new MemoryData((void *)(new _Float16[dim.getDataLen()]()));
+      new MemoryData((void *)(new _FP16[dim.getDataLen()]()));
     data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
-      delete[] mem_data->getAddr<_Float16>();
+      delete[] mem_data->getAddr<_FP16>();
     });
     offset = 0;
     contiguous = true;
@@ -326,7 +326,7 @@ public:
    * @note      This constructor copies vector again. needs refactoring
    * @param[in] d data for the Tensor
    */
-  Tensor(std::vector<std::vector<std::vector<_Float16>>> const &d,
+  Tensor(std::vector<std::vector<std::vector<_FP16>>> const &d,
          ml::train::TensorDim::TensorType t_type) :
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
 
@@ -335,7 +335,7 @@ public:
    * @note      This constructor copies vector again. needs refactoring
    * @param[in] d data for the Tensor with batch size one
    */
-  Tensor(std::vector<std::vector<_Float16>> const &d,
+  Tensor(std::vector<std::vector<_FP16>> const &d,
          ml::train::TensorDim::TensorType t_type) :
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
 
@@ -1168,6 +1168,7 @@ public:
    * @param[out] output output tensor
    * @retval    Tensor
    */
+  
   Tensor &apply(std::function<float(float)> f, Tensor &output) const {
     CREATE_IF_EMPTY_DIMS(output, dim, nullptr);
 
@@ -1205,62 +1206,30 @@ public:
           }
         }
       }
-    } 
-    return output;
-  };
-
-  /**
-   * @brief Apply instantly to the element
-   *
-   * @param f function to apply
-   * @return int ML_ERROR_NONE if successful
-   */
-  int apply_i(std::function<_Float16(_Float16)> f) {
-    Tensor result = *this;
-    apply(f, result);
-
-    return ML_ERROR_NONE;
-  };
+    } else if (dim.getDataType() == Tdatatype::FP16) {
 
-  /**
-   * @brief     Apply function element by element
-   * @param[in] *function function pointer applied
-   * @retval    Tensor
-   */
-  Tensor apply(std::function<_Float16(_Float16)> f) const {
-    Tensor result;
-    return apply(f, result);
-  };
+      auto f_16 = [f](_FP16 x) -> _FP16 {
+        return static_cast<_FP16>(f(static_cast<float>(x)));
+      };
 
-  /**
-   * @brief     Apply function element by element
-   * @param[in] *function function pointer applied
-   * @param[out] output output tensor
-   * @retval    Tensor
-   */
-  Tensor &apply(std::function<_Float16(_Float16)> f, Tensor &output) const {
-    CREATE_IF_EMPTY_DIMS(output, dim, nullptr);
+      // std::function<_FP16(_FP16)> f_16 =
+      //   static_cast<std::function<_FP16(_FP16)>>(f);
 
-    if (dim != output.dim) {
-      /// @todo add unittest
-      throw std::invalid_argument(
-        "[Tensor::apply] output dimension does not match");
-    }
 
-    #ifdef ENABLE_FP16
+      
       if (contiguous && output.contiguous) {
-        const _Float16 *data = (getData<_Float16>());
-        _Float16 *rdata = (output.getData<_Float16>());
+        const _FP16 *data = (getData<_FP16>());
+        _FP16 *rdata = (output.getData<_FP16>());
 
-        std::transform(data, data + size(), rdata, f);
+        std::transform(data, data + size(), rdata, f_16);
       } else if (strides[3] == 1 && output.strides[3] == 1) {
         /** @todo optimize this with combining these loops where stride is 1 */
         for (unsigned int b = 0; b < batch(); ++b) {
           for (unsigned int c = 0; c < channel(); ++c) {
             for (unsigned int h = 0; h < height(); ++h) {
-              _Float16 *out_data = (_Float16 *)output.getAddress(b, c, h, 0);
-              const _Float16 *in_data = (_Float16 *)getAddress(b, c, h, 0);
-              std::transform(in_data, in_data + width(), out_data, f);
+              _FP16 *out_data = output.getAddress<_FP16>(b, c, h, 0);
+              const _FP16 *in_data = getAddress<_FP16>(b, c, h, 0);
+              std::transform(in_data, in_data + width(), out_data, f_16);
             }
           }
         }
@@ -1269,20 +1238,90 @@ public:
           for (unsigned int c = 0; c < channel(); ++c) {
             for (unsigned int h = 0; h < height(); ++h) {
               for (unsigned int w = 0; w < width(); ++w) {
-                output.setValue(b, c, h, w,
-                                f((_Float16)((_Float16)getValue(b, c, h, w))));
+                output.setValue(b, c, h, w, f_16(getValue<_FP16>(b, c, h, w)));
               }
             }
           }
         }
       }
-    #else
-      throw std::invalid_argument("Error: enable-fp16 is not enabled");
-    #endif
-  
+    }
     return output;
   };
 
+  // /**
+  //  * @brief Apply instantly to the element
+  //  *
+  //  * @param f function to apply
+  //  * @return int ML_ERROR_NONE if successful
+  //  */
+  // int apply_i(std::function<_FP16(_FP16)> f) {
+  //   Tensor result = *this;
+  //   apply(f, result);
+
+  //   return ML_ERROR_NONE;
+  // };
+
+  // /**
+  //  * @brief     Apply function element by element
+  //  * @param[in] *function function pointer applied
+  //  * @retval    Tensor
+  //  */
+  // Tensor apply(std::function<_FP16(_FP16)> f) const {
+  //   Tensor result;
+  //   return apply(f, result);
+  // };
+
+  // /**
+  //  * @brief     Apply function element by element
+  //  * @param[in] *function function pointer applied
+  //  * @param[out] output output tensor
+  //  * @retval    Tensor
+  //  */
+  // Tensor &apply(std::function<_FP16(_FP16)> f, Tensor &output) const {
+  //   CREATE_IF_EMPTY_DIMS(output, dim, nullptr);
+
+  //   if (dim != output.dim) {
+  //     /// @todo add unittest
+  //     throw std::invalid_argument(
+  //       "[Tensor::apply] output dimension does not match");
+  //   }
+
+  //   #ifdef ENABLE_FP16
+  //     if (contiguous && output.contiguous) {
+  //       const _FP16 *data = (getData<_FP16>());
+  //       _FP16 *rdata = (output.getData<_FP16>());
+
+  //       std::transform(data, data + size(), rdata, f);
+  //     } else if (strides[3] == 1 && output.strides[3] == 1) {
+  //       /** @todo optimize this with combining these loops where stride is 1 */
+  //       for (unsigned int b = 0; b < batch(); ++b) {
+  //         for (unsigned int c = 0; c < channel(); ++c) {
+  //           for (unsigned int h = 0; h < height(); ++h) {
+  //             _FP16 *out_data = (_FP16 *)output.getAddress(b, c, h, 0);
+  //             const _FP16 *in_data = (_FP16 *)getAddress(b, c, h, 0);
+  //             std::transform(in_data, in_data + width(), out_data, f);
+  //           }
+  //         }
+  //       }
+  //     } else {
+  //       for (unsigned int b = 0; b < batch(); ++b) {
+  //         for (unsigned int c = 0; c < channel(); ++c) {
+  //           for (unsigned int h = 0; h < height(); ++h) {
+  //             for (unsigned int w = 0; w < width(); ++w) {
+  //               output.setValue(b, c, h, w,
+  //                               f((_FP16)((_FP16)getValue(b, c, h, w))));
+  //             }
+  //           }
+  //         }
+  //       }
+  //     }
+  //   #else
+  //     throw std::invalid_argument("Error: enable-fp16 is not enabled");
+  //   #endif
+  
+  //   return output;
+  // };
+
   /**
    * @brief     Apply function to Tensor
    * @param[in] *function function pointer applied
@@ -1347,7 +1386,7 @@ public:
       getData<float>()[getIndex(batch, c, h, w)] = value;
     } else if (getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-      getData<_Float16>()[getIndex(batch, c, h, w)] = static_cast<_Float16>(value);
+      getData<_FP16>()[getIndex(batch, c, h, w)] = static_cast<_FP16>(value);
 #else
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
@@ -1371,8 +1410,8 @@ public:
       getData<float>()[idx] += value;
     } else if (dim.getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-      getData<_Float16>()[idx] *= static_cast<_Float16>(beta);
-      getData<_Float16>()[idx] += static_cast<_Float16>(value);
+      getData<_FP16>()[idx] *= static_cast<_FP16>(beta);
+      getData<_FP16>()[idx] += static_cast<_FP16>(value);
 #else
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
@@ -1898,16 +1937,16 @@ private:
 #ifdef ENABLE_FP16
   void apply_broadcast_util(
     Tensor const &m,
-    std::function<void(const BroadcastInfo &e, const _Float16 *, const _Float16 *,
-                       _Float16 *)>
+    std::function<void(const BroadcastInfo &e, const _FP16 *, const _FP16 *,
+                       _FP16 *)>
       v_func,
     Tensor &output, const BroadcastInfo &e, int cur_axis = -1,
     size_t offset = 0, size_t m_offset = 0) const;
 
   void
   apply_broadcast(Tensor const &m,
-                  std::function<void(const BroadcastInfo &e, const _Float16 *,
-                                     const _Float16 *, _Float16 *)>
+                  std::function<void(const BroadcastInfo &e, const _FP16 *,
+                                     const _FP16 *, _FP16 *)>
                     v_func,
                   Tensor &output) const;
 #endif
index b6e77c369dff04e7ace7bd659cf3fccb86f6b763..0cf53c8947c2e9f49e792eaf5c1cef6077d7ee1a 100644 (file)
@@ -118,7 +118,7 @@ uint TensorDim::getDataTypeSize() const {
   switch (t_type.data_type) {
   case TensorDim::DataType::FP16:
 #ifdef ENABLE_FP16
-    return sizeof(_Float16);
+    return sizeof(_FP16);
 #else
     return 2;
 #endif
index d42764c2bff1391040a8cb3b012eaec240243553..da9f6c86582520e1b0ce41230f6f0944cb186e24 100644 (file)
@@ -40,6 +40,7 @@ float sqrtFloat(float x) { return sqrt(x); };
 double sqrtDouble(double x) { return sqrt(x); };
 
 float logFloat(float x) { return log(x + 1.0e-20); }
+1103
 
 float exp_util(float x) { return exp(x); }
 
index eded7d266be7f3b618d2770b05518a19f78dccbc..4996303c21b39b79a191310ba5e750a300ef917b 100644 (file)
@@ -188,14 +188,13 @@ nntrainer::Tensor ranged(unsigned int batch, unsigned int channel,
                          nntrainer::Tformat fm, nntrainer::Tdatatype d_type) {
   nntrainer::TensorDim::TensorType t_type(fm, d_type);
   nntrainer::Tensor t(batch, channel, height, width, t_type);
-  if (t_type.data_type == nntrainer::Tdatatype::FP32) {
+  // if (t_type.data_type == nntrainer::Tdatatype::FP32) {
     float i = 0;
     t = t.apply((std::function<float(float)>)[&](float in) { return i++; });
-  } else if (t_type.data_type == nntrainer::Tdatatype::FP16) {
-    _Float16 i = 0;
-    t = t.apply(
-      (std::function<_Float16(_Float16)>)[&](_Float16 in) { return i++; });
-  }
+  // } else if (t_type.data_type == nntrainer::Tdatatype::FP16) {
+  //   _FP16 i = 0;
+  //   t = t.apply((std::function<_FP16(_FP16)>)[&](_FP16 in) { return i++; });
+  // }
 
   return t;
 }
index 1f42959f4b226c3d557bce165ca8ac5bf7693af9..add6e8cab3421e766b20a762a3e77e863fea4392 100644 (file)
@@ -44,14 +44,14 @@ auto fc_basic_plain_nhwc = LayerGoldenTestParamType(
   "3:10:1:1", "fc_plain.nnlayergolden",
   LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
     LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
-  "nhwc");
+  "nhwc", "fp32", "fp32");
 
 auto fc_basic_single_batch_nhwc = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=4"},
   "1:10:1:1", "fc_single_batch.nnlayergolden",
   LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
     LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
-  "nhwc");
+  "nhwc", "fp32","fp32");
 
 auto fc_basic_no_decay_nhwc = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>,
@@ -59,10 +59,15 @@ auto fc_basic_no_decay_nhwc = LayerGoldenTestParamType(
   "fc_plain.nnlayergolden",
   LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
     LayerGoldenTestParamOptions::SKIP_CALC_GRAD,
-  "nhwc");
+  "nhwc","fp32","fp32");
+
+auto fc_basic_plain_fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=5"},
+  "3:1:1:10", "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw", "fp16", "fp16");
 
 GTEST_PARAMETER_TEST(FullyConnected, LayerGoldenTest,
                      ::testing::Values(fc_basic_plain, fc_basic_single_batch,
                                        fc_basic_no_decay, fc_basic_plain_nhwc,
                                        fc_basic_single_batch_nhwc,
-                                       fc_basic_no_decay_nhwc));
+                                       fc_basic_no_decay_nhwc, fc_basic_plain_fp16));
index 77f48fdff936f0ac735d79ba548d7f239fa8af31..fcbdc088a86b21a61d0a9a6013dce1d9e0fc4937 100644 (file)
@@ -166,7 +166,8 @@ TEST(nntrainer_activation, DISABLED_sigmoidPrime_01_p) {
   nntrainer::Tensor input(batch, channel, height, width);
   GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
 
-  nntrainer::Tensor sigmoid_result = input.apply(nntrainer::ActiFunc::sigmoid);
+  nntrainer::Tensor sigmoid_result =
+    input.apply(nntrainer::ActiFunc::sigmoid);
   float *data = sigmoid_result.getData();
   ASSERT_NE(nullptr, data);
 
index 385c4b6d0f0d2b400fef5b61a440242a174b147f..bc286124c87310d4a8cd6920e8dd8c7a265294cb 100644 (file)
@@ -3752,7 +3752,7 @@ TEST(nntrainer_Tensor, average_axis_p) {
   nntrainer::Tensor t = constant(1.0, 2, 2, 2, 2, nntrainer::Tformat::NCHW,
                                  nntrainer::Tdatatype::FP16);
   int idx = 0;
-  std::function<_Float16(_Float16)> f = [&](_Float16 in) { return idx++ % 2; };
+  std::function<float(float)> f = [&](float in) { return idx++ % 2; };
   t = t.apply(f);
 
   nntrainer::Tensor actual, expected;