[ Mixed Precision ] Support Mixed Precision
authorjijoong.moon <jijoong.moon@samsung.com>
Wed, 12 Jul 2023 07:58:48 +0000 (16:58 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
This PR enables the Mixed Precision computation.
- Add the data_type property in Tensor : FP16, FP32
- Memory_Data only handle void *
- In Tensor, there were several member function with template
   : getAddress<float>() , getData<__fp16>, etc.
- Need to implement Blas Interface function

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
33 files changed:
jni/Android.mk.in
jni/meson.build
meson_options.txt
nntrainer/dataset/databuffer.cpp
nntrainer/dataset/random_data_producers.cpp
nntrainer/layers/acti_func.cpp
nntrainer/layers/centroid_knn.cpp
nntrainer/layers/concat_layer.cpp
nntrainer/layers/conv2d_layer.cpp
nntrainer/layers/embedding.cpp
nntrainer/layers/gru.cpp
nntrainer/layers/lstm.cpp
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/pooling2d_layer.cpp
nntrainer/layers/preprocess_flip_layer.cpp
nntrainer/layers/preprocess_translate_layer.cpp
nntrainer/layers/rnn.cpp
nntrainer/models/dynamic_training_optimization.cpp
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_interface.h
nntrainer/tensor/cache_elem.cpp
nntrainer/tensor/cache_elem.h
nntrainer/tensor/cache_pool.cpp
nntrainer/tensor/cache_pool.h
nntrainer/tensor/memory_data.h
nntrainer/tensor/memory_pool.cpp
nntrainer/tensor/memory_pool.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/utils/util_func.cpp
nntrainer/utils/util_func.h
tools/package_android.sh

index 5626f32..2dc6ae8 100644 (file)
@@ -20,15 +20,15 @@ include $(PREBUILT_STATIC_LIBRARY)
 
 endif # MESON_HAS_TFLITE
 
-include $(CLEAR_VARS)
+include $(CLEAR_VARS)
 
-LOCAL_MODULE := openblas
+LOCAL_MODULE := openblas
 
-LOCAL_SRC_FILES := @MESON_BLAS_ROOT@/lib/libopenblas.a
-LOCAL_EXPORT_C_INCLUDES := @MESON_BLAS_ROOT@/include
-LOCAL_EXPORT_CFLAGS += -DUSE_BLAS=1
+LOCAL_SRC_FILES := @MESON_BLAS_ROOT@/lib/libopenblas.a
+LOCAL_EXPORT_C_INCLUDES := @MESON_BLAS_ROOT@/include
+LOCAL_EXPORT_CFLAGS += -DUSE_BLAS=1
 
-include $(PREBUILT_STATIC_LIBRARY)
+include $(PREBUILT_STATIC_LIBRARY)
 
 include $(CLEAR_VARS)
 
@@ -66,7 +66,7 @@ LOCAL_MODULE_TAGS   := optional
 
 LOCAL_LDLIBS        := -llog -landroid -fopenmp -static-openmp
 
-LOCAL_STATIC_LIBRARIES += iniparser openblas
+LOCAL_STATIC_LIBRARIES += iniparser #openblas
 
 ifeq ($(MESON_HAS_TFLITE), 1)
   LOCAL_STATIC_LIBRARIES += tensorflow-lite
index 6971185..f780e0d 100644 (file)
@@ -5,6 +5,10 @@ if get_option('debug')
     flags += '-g'
 endif
 
+if get_option('enable-fp16')
+   flags += '-march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp'
+endif
+
 flags += '-O@0@'.format(get_option('optimization'))
 message('compile flags are:' + ' '.join(flags))
 
@@ -35,11 +39,11 @@ else
   and_conf.set('MESON_HAS_TFLITE', 0)
 endif
 
-if blas_dep.found()
-  and_conf.set('MESON_BLAS_ROOT', blas_root)
-else
-  error('blas is needed for the android build')
-endif
+if blas_dep.found()
+  and_conf.set('MESON_BLAS_ROOT', blas_root)
+else
+  error('blas is needed for the android build')
+endif
 
 if ml_api_common_dep.found()
   and_conf.set('MESON_ML_API_COMMON_ROOT', ml_api_common_root)
index 6826f8e..61e659a 100644 (file)
@@ -33,7 +33,8 @@ option('reduce-tolerance', type: 'boolean', value: true)
 option('enable-long-test', type: 'boolean', value: false)
 
 # backend options
-option('enable-blas', type: 'boolean', value: true)
+option('enable-blas', type: 'boolean', value: false)
+option('enable-fp16', type: 'boolean', value: true)
 option('enable-cublas', type: 'boolean', value: false)
 option('enable-openmp', type: 'boolean', value: true)
 
index 6e60038..3143ae1 100644 (file)
@@ -64,7 +64,8 @@ DataBuffer::DataBuffer(std::unique_ptr<DataProducer> &&producer_) :
   producer(std::move(producer_)),
   db_props(new Props()),
   user_data(nullptr) {
-  rng.seed(getSeed());
+  // rng.seed(getSeed());
+  rng.seed(0);
 }
 
 DataBuffer::~DataBuffer(){};
index c6a51fa..6c03fa1 100644 (file)
@@ -126,7 +126,8 @@ RandomDataOneHotProducer::finalize(const std::vector<TensorDim> &input_dims,
                  });
 
   std::mt19937 rng;
-  rng.seed(getSeed());
+  // rng.seed(getSeed());
+  rng.seed(0);
   auto sz = size(input_dims, input_dims);
 
   /** DataProducer::Generator */
index 8fa2b6d..d8fef83 100644 (file)
@@ -224,9 +224,9 @@ Tensor &ActiFunc::softmax(Tensor const &input, Tensor &output) {
 
   for (unsigned int i = 0; i < bch_size; i++) {
     float *ptr = output_data + i * width;
-    std::transform(
-      ptr, ptr + width, ptr,
-      std::bind(std::divides<float>(), std::placeholders::_1, sum.getValue(i)));
+    std::transform(ptr, ptr + width, ptr,
+                   std::bind(std::divides<float>(), std::placeholders::_1,
+                             sum.getValue<float>(i)));
   }
 
   return output;
index c20fb0f..611dca1 100644 (file)
@@ -95,7 +95,7 @@ void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
 
       //  nntrainer::Tensor::Map(map.getData(), {feature_len},
       // ans[b] * feature_len);
-      auto num_sample = num_samples.getValue(0, 0, 0, ans[b]);
+      float num_sample = num_samples.getValue<float>(0, 0, 0, ans[b]);
       auto current_feature = input_.getBatchSlice(b, 1);
       saved_feature.multiply_i(num_sample);
       saved_feature.add_i(current_feature);
index e792ba9..5d05e12 100644 (file)
@@ -126,10 +126,10 @@ void ConcatLayer::forwarding(RunLayerContext &context, bool training) {
     for (unsigned int batch = 0; batch < output.batch(); batch++) {
       /** loop over the concat dimension itself */
       for (unsigned int count = 0; count < irh.height(); count++) {
-        Tensor dest_tensor = Tensor::Map(
-          output.getAddress(batch, 0, output_height_offset + count, 0),
+        Tensor dest_tensor = Tensor::Map<float>(
+          output.getAddress<float>(batch, 0, output_height_offset + count, 0),
           data_copy_size * sizeof(float), {1, 1, 1, data_copy_size});
-        const Tensor source_tensor = Tensor::Map(
+        const Tensor source_tensor = Tensor::Map<float>(
           input.getAddress(batch, 0, count, 0), data_copy_size * sizeof(float),
           {1, 1, 1, data_copy_size});
         dest_tensor.copy(source_tensor);
@@ -164,10 +164,10 @@ void ConcatLayer::calcDerivative(RunLayerContext &context) {
     for (unsigned int batch = 0; batch < output.batch(); batch++) {
       /** loop over the concat dimension itself */
       for (unsigned int count = 0; count < irh.height(); count++) {
-        const Tensor source_tensor = Tensor::Map(
+        const Tensor source_tensor = Tensor::Map<float>(
           output.getAddress(batch, 0, output_height_offset + count, 0),
           data_copy_size * sizeof(float), {1, 1, 1, data_copy_size});
-        Tensor dest_tensor = Tensor::Map(input.getAddress(batch, 0, count, 0),
+        Tensor dest_tensor = Tensor::Map<float>(input.getAddress(batch, 0, count, 0),
                                          data_copy_size * sizeof(float),
                                          {1, 1, 1, data_copy_size});
         dest_tensor.copy(source_tensor);
index 012587a..9b245a0 100644 (file)
@@ -102,8 +102,8 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim,
               continue;
             }
 
-            float *val = image.getAddress(0, c, h, w);
-            *val += col_matrix.getValue(0, 0, col_h, col_w);
+            float *val = (float *)image.getAddress(0, c, h, w);
+            *val += col_matrix.getValue<float>(0, 0, col_h, col_w);
             col_h++;
           }
         }
@@ -229,7 +229,7 @@ static void im2col(const Tensor &in, const TensorDim &kdim,
               im_h++;
               continue;
             }
-            out_data[im_w * owidth + im_h] = in.getValue(0, c, h, w);
+            out_data[im_w * owidth + im_h] = in.getValue<float>(0, c, h, w);
             im_h++;
           }
           im_w++;
index e39fba3..3fbd137 100644 (file)
@@ -81,7 +81,8 @@ void EmbeddingLayer::forwarding(RunLayerContext &context, bool training) {
   TensorDim out_tensor_dim = TensorDim({1, 1, 1, out_dim});
 
   for (unsigned int b = 0; b < input_.batch(); ++b) {
-    float *in_data = input_.getAddress(b * input_.getDim().getFeatureLen());
+    float *in_data =
+      (float *)input_.getAddress(b * input_.getDim().getFeatureLen());
 
     Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1);
     for (unsigned int i = 0; i < input_.width(); ++i) {
@@ -139,7 +140,7 @@ void EmbeddingLayer::calcGradient(RunLayerContext &context) {
   // In order to accelerate, we need to better way like using index to weight.
 
   for (unsigned int b = 0; b < input_.batch(); ++b) {
-    float *in_data = input_.getAddress(b * input_.getDim().getFeatureLen());
+    float *in_data = (float*)input_.getAddress(b * input_.getDim().getFeatureLen());
 
     for (unsigned int i = 0; i < input_.width(); ++i) {
       uint embed_idx = ((uint *)(in_data))[i];
@@ -148,8 +149,8 @@ void EmbeddingLayer::calcGradient(RunLayerContext &context) {
       // if (embed_idx == 0)
       //   continue;
 
-      float *djdw_data = djdw.getAddress(embed_idx * out_dim);
-      const float *grad_data = derivative_.getAddress(
+      float *djdw_data = (float*)djdw.getAddress(embed_idx * out_dim);
+      const float *grad_data = (float*)derivative_.getAddress(
         b * derivative_.getDim().getFeatureLen() + i * out_dim);
 
       std::transform(djdw_data, djdw_data + out_dim, grad_data, djdw_data,
index 8cf6a08..8f68cb5 100644 (file)
@@ -539,14 +539,15 @@ void GRULayer::calcGradient(RunLayerContext &context) {
     }
   }
   for (unsigned int h = 0; h < unit; ++h) {
-    float *data = djdweight_hh_zr.getAddress(h * unit * 2);
-    float *rdata = djdweight_hh.getAddress(h * unit * NUM_GATE);
+    float *data = (float *)djdweight_hh_zr.getAddress(h * unit * 2);
+    float *rdata = (float *)djdweight_hh.getAddress(h * unit * NUM_GATE);
     std::copy(data, data + unit * 2, rdata);
   }
 
   for (unsigned int h = 0; h < unit; ++h) {
-    float *data = djdweight_hh_g.getAddress(h * unit);
-    float *rdata = djdweight_hh.getAddress(h * unit * NUM_GATE + unit * 2);
+    float *data = (float *)djdweight_hh_g.getAddress(h * unit);
+    float *rdata =
+      (float *)djdweight_hh.getAddress(h * unit * NUM_GATE + unit * 2);
     std::copy(data, data + unit, rdata);
   }
 }
index d235139..bc3d750 100644 (file)
@@ -662,22 +662,23 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     unsigned int end_timestep = return_sequences ? max_timestep : 1;
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
       for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
-        float *hidden_state_data = hidden_state.getAddress(
+        float *hidden_state_data = hidden_state.getAddress<float>(
           batch * max_timestep * unit +
           (return_sequences ? 0 : (max_timestep - 1) * unit) + timestep * unit);
-        float *output_data =
-          output.getAddress(batch * (return_sequences ? max_timestep : 1) *
-                              bidirectional_constant * unit +
-                            timestep * bidirectional_constant * unit);
+        float *output_data = output.getAddress<float>(
+          batch * (return_sequences ? max_timestep : 1) *
+            bidirectional_constant * unit +
+          timestep * bidirectional_constant * unit);
         std::copy(hidden_state_data, hidden_state_data + unit, output_data);
 
         if (bidirectional) {
           Tensor &reverse_hidden_state =
             context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
-          float *reverse_hidden_state_data = reverse_hidden_state.getAddress(
-            batch * max_timestep * unit +
-            (return_sequences ? 0 : (max_timestep - 1) * unit) +
-            timestep * unit);
+          float *reverse_hidden_state_data =
+            reverse_hidden_state.getAddress<float>(
+              batch * max_timestep * unit +
+              (return_sequences ? 0 : (max_timestep - 1) * unit) +
+              timestep * unit);
           std::copy(reverse_hidden_state_data, reverse_hidden_state_data + unit,
                     output_data + unit);
         }
index 4d3c5f0..db32732 100644 (file)
@@ -144,11 +144,11 @@ void LSTMCore::calcGradientLSTM(
   } else {
     for (unsigned int i = 0; i < d_weight_ih.height(); ++i) {
       unsigned int out_width = d_weight_ih.width();
-      float in_ih = input.getValue(i);
+      float in_ih = input.getValue<float>(i);
 
-      float *d_weight_ih_address = d_weight_ih.getAddress(i * out_width);
+      float *d_weight_ih_address = d_weight_ih.getAddress<float>(i * out_width);
 
-      float *d_ifgo_address = d_ifgo.getData();
+      float *d_ifgo_address = d_ifgo.getData<float>();
 #ifdef USE_BLAS
       cblas_saxpy(out_width, in_ih, d_ifgo_address, 1, d_weight_ih_address, 1);
 #else
@@ -164,11 +164,11 @@ void LSTMCore::calcGradientLSTM(
   } else {
     for (unsigned int i = 0; i < d_weight_hh.height(); ++i) {
       unsigned int out_width = d_weight_hh.width();
-      float in_hh = prev_hidden_state.getValue(i);
+      float in_hh = prev_hidden_state.getValue<float>(i);
 
-      float *d_weight_hh_address = d_weight_hh.getAddress(i * out_width);
+      float *d_weight_hh_address = d_weight_hh.getAddress<float>(i * out_width);
 
-      float *d_ifgo_address = d_ifgo.getData();
+      float *d_ifgo_address = d_ifgo.getData<float>();
 
 #ifdef USE_CBLAS
       cblas_saxpy(out_width, in_hh, d_ifgo_address, 1, d_weight_hh_address, 1);
index 64381c5..40c8f31 100644 (file)
@@ -219,7 +219,7 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
   for (unsigned int b = 0; b < batch; b++) {
     for (unsigned int h = 0; h < u_base.height(); h++) {
-      float *u_data = u_base.getAddress(b, 0, h, 0);
+      float *u_data = u_base.getAddress<float>(b, 0, h, 0);
       std::fill(u_data, u_data + u_base.width(), h + 1);
     }
   }
index b99d5af..a68e42e 100644 (file)
@@ -207,7 +207,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) {
         for (int j = -pt; j <= height_stride_end; j += stride[0]) {
           K = 0;
           for (int k = -pl; k <= width_stride_end; k += stride[1]) {
-            float del = deriv.getValue(b, i, J, K) / *iter;
+            float del = deriv.getValue<float>(b, i, J, K) / *iter;
             int patch_height_end =
               std::min(static_cast<int>(j + p_height), height);
             int patch_width_end =
@@ -216,7 +216,8 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) {
             int start_w = std::max(0, k);
             for (int h = start_h; h < patch_height_end; ++h) {
               for (int w = start_w; w < patch_width_end; ++w) {
-                result.setValue(b, i, h, w, result.getValue(b, i, h, w) + del);
+                result.setValue(b, i, h, w,
+                                result.getValue<float>(b, i, h, w) + del);
               }
             }
             iter++;
index 74253d1..57e4df1 100644 (file)
@@ -30,7 +30,7 @@ PreprocessFlipLayer::PreprocessFlipLayer() :
 void PreprocessFlipLayer::finalize(InitLayerContext &context) {
   context.setOutputDimensions(context.getInputDimensions());
 
-  rng.seed(getSeed());
+  rng.seed(0);
   flip_dist = std::uniform_real_distribution<float>(0.0, 1.0);
 }
 
@@ -81,15 +81,15 @@ void PreprocessFlipLayer::forwarding(RunLayerContext &context, bool training) {
         for (unsigned int c = 0; c < input_dim.channel(); c++)
           for (unsigned int h = 0; h < input_dim.height(); h++)
             for (unsigned int w = 0; w < input_dim.width() / 2; w++)
-              swap(*input_.getAddress(b, c, h, w),
-                   *input_.getAddress(b, c, h, width - w - 1));
+              swap(*input_.getAddress<float>(b, c, h, w),
+                   *input_.getAddress<float>(b, c, h, width - w - 1));
       }
       if (fliph) {
         for (unsigned int c = 0; c < input_dim.channel(); c++)
           for (unsigned int h = 0; h < input_dim.height() / 2; h++)
             for (unsigned int w = 0; w < input_dim.width(); w++)
-              swap(*input_.getAddress(b, c, h, w),
-                   *input_.getAddress(b, c, height - h - 1, w));
+              swap(*input_.getAddress<float>(b, c, h, w),
+                   *input_.getAddress<float>(b, c, height - h - 1, w));
       }
     }
     /** @todo enable inPlace support for this layer */
index fd0a0f6..813af8f 100644 (file)
@@ -38,7 +38,8 @@ void PreprocessTranslateLayer::finalize(InitLayerContext &context) {
   float random_translate =
     std::get<props::RandomTranslate>(preprocess_translate_props);
 
-  rng.seed(getSeed());
+  // rng.seed(getSeed());
+  rng.seed(0);
 
   // Made for 3 channel input
   if (random_translate > epsilon) {
index 38ba4cb..8ac74bd 100644 (file)
@@ -234,9 +234,9 @@ void RNNLayer::forwarding(RunLayerContext &context, bool training) {
 
   if (!return_sequences) {
     for (unsigned int batch = 0; batch < input_dim.batch(); ++batch) {
-      float *hidden_state_data = hidden_state.getAddress(
+      float *hidden_state_data = hidden_state.getAddress<float>(
         batch * unit * max_timestep + (max_timestep - 1) * unit);
-      float *output_data = output.getAddress(batch * unit);
+      float *output_data = output.getAddress<float>(batch * unit);
       std::copy(hidden_state_data, hidden_state_data + unit, output_data);
     }
   } else {
@@ -301,10 +301,11 @@ void RNNLayer::calcGradient(RunLayerContext &context) {
 
   if (!return_sequences) {
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
-      float *hidden_state_derivative_data = hidden_state_derivative.getAddress(
-        batch * unit * max_timestep + (max_timestep - 1) * unit);
+      float *hidden_state_derivative_data =
+        hidden_state_derivative.getAddress<float>(batch * unit * max_timestep +
+                                                  (max_timestep - 1) * unit);
       const float *incoming_derivative_data =
-        incoming_derivative.getAddress(batch * unit);
+        (float *)incoming_derivative.getAddress<float>(batch * unit);
       std::copy(incoming_derivative_data, incoming_derivative_data + unit,
                 hidden_state_derivative_data);
     }
index 569fe6b..a37a569 100644 (file)
@@ -29,7 +29,8 @@ DynamicTrainingOptimization::DynamicTrainingOptimization(int threshold_,
   skip_n_iterations(skip_n_iter) {
   reduce_op = reduceByNorm;
   calc_ratio_op = ratioUsingDerivative;
-  rng.seed(getSeed());
+  // rng.seed(getSeed());
+  rng.seed(0);
   dist = std::uniform_real_distribution<float>(0.0, 1.0);
 }
 
index e413f38..92b2720 100644 (file)
@@ -13,6 +13,7 @@
 
 #include <blas_interface.h>
 #include <nntrainer_error.h>
+#include <iostream>
 
 #include <cmath>
 
@@ -49,6 +50,7 @@ static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
 
   unsigned int incy = abs(incY);
   unsigned int incx = abs(incX);
+
   if (TransA == CblasTrans) {
     sgemv_loop(i, j, N, M);
   } else {
@@ -75,6 +77,15 @@ static void scopy_raw(const unsigned int N, const float *X, const int incX,
     Y[i * incy] = X[i * incx];
 }
 
+static void scopy_FP16(const unsigned int N, const __fp16 *X, const int incX,
+                       __fp16 *Y, const int incY) {
+  unsigned int incy = abs(incY);
+  unsigned int incx = abs(incX);
+
+  for (unsigned int i = 0; i < N; ++i)
+    Y[i * incy] = X[i * incx];
+}
+
 static void sscal_raw(const unsigned int N, const float alpha, float *X,
                       const int incX) {
   unsigned int incx = abs(incX);
@@ -83,6 +94,41 @@ static void sscal_raw(const unsigned int N, const float alpha, float *X,
     X[i * incx] = alpha * X[i * incx];
 }
 
+void sscal(const unsigned int N, const float alpha, __fp16 *X, const int incX) {
+  unsigned int incx = abs(incX);
+
+  for (unsigned int i = 0; i < N; ++i)
+    X[i * incx] = alpha * X[i * incx];
+}
+
+void sscal(const unsigned int N, const float alpha, void *X, const int incX,
+           DataType d_type) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  if (d_type == DataType::FP32)
+    cblas_sscal(N, alpha, (float *)X, incX);
+#else
+  if (d_type == DataType::FP32) {
+    sscal_raw(N, alpha, (float *)X, incX);
+  } else if (d_type == DataType::FP16) {
+    sscal(N, alpha, (__fp16 *)X, incX);
+  }
+#endif
+}
+
+void sscal(const unsigned int N, const float alpha, float *X, const int incX) {
+#ifdef USE_BLAS
+#ifdef BLAS_NUM_THREADS
+  openblas_set_num_threads(BLAS_NUM_THREADS);
+#endif
+  cblas_sscal(N, alpha, (float *)X, incX);
+#else
+  sscal_raw(N, alpha, (float *)X, incX);
+#endif
+}
+
 static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
   unsigned int incx = abs(incX);
   float sum = 0.0f;
@@ -193,28 +239,41 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #endif
 }
 
-void scopy(const unsigned int N, const float *X, const int incX, float *Y,
-           const int incY) {
+void scopy(const unsigned int N, const void *X, const int incX, void *Y,
+           const int incY, DataType d_type) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  cblas_scopy(N, X, incX, Y, incY);
+  if (d_type == DataType::FP32) {
+    cblas_scopy(N, (float *)X, incX, (float *)Y, incY);
+  }
 #else
-  scopy_raw(N, X, incX, Y, incY);
+  if (d_type == DataType::FP32) {
+    scopy_raw(N, (float *)X, incX, (float *)Y, incY);
+  } else if (d_type == DataType::FP16) {
+    scopy_FP16(N, (__fp16 *)X, incX, (__fp16 *)Y, incY);
+  }
 #endif
-}
+} // namespace nntrainer
 
-void sscal(const int N, const float alpha, float *X, const int incX) {
+void scopy(const unsigned int N, const float *X, const int incX, float *Y,
+           const int incY) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  cblas_sscal(N, alpha, X, incX);
+  cblas_scopy(N, X, incX, Y, incY);
 #else
-  sscal_raw(N, alpha, X, incX);
+  scopy_raw(N, X, incX, Y, incY);
 #endif
-}
+} // namespace nntrainer
+
+void scopy(const unsigned int N, const __fp16 *X, const int incX, __fp16 *Y,
+           const int incY) {
+  scopy_FP16(N, X, incX, Y, incY);
+
+} // namespace nntrainer
 
 float snrm2(const int N, const float *X, const int incX) {
 #ifdef USE_BLAS
index 18cf9f5..8bab311 100644 (file)
@@ -27,6 +27,7 @@ enum CBLAS_TRANSPOSE {
   CblasTrans = 112,
   CblasConjTrans = 113
 };
+
 #endif
 
 #ifdef USE_CUBLAS
@@ -36,13 +37,29 @@ enum CBLAS_TRANSPOSE {
 
 namespace nntrainer {
 
-void sscal(const int N, const float alpha, float *X, const int incX);
+enum class DataType {
+  FP16, /** half precion */
+  FP32  /** single precision */
+};
+
+void sscal(const unsigned int N, const float alpha, void *X, const int incX,
+           DataType d_type);
+
+void sscal(const unsigned int N, const float alpha, float *X, const int incX);
+
+void sscal(const unsigned int N, const float alpha, __fp16 *X, const int incX);
 
 float snrm2(const int N, const float *X, const int incX);
 
+void scopy(const unsigned int N, const void *X, const int incX, void *Y,
+           const int incY, DataType d_type);
+
 void scopy(const unsigned int N, const float *X, const int incX, float *Y,
            const int intY);
 
+void scopy(const unsigned int N, const __fp16 *X, const int incX, __fp16 *Y,
+           const int intY);
+
 float sdot(const unsigned int N, const float *X, const unsigned int incX,
            const float *Y, const unsigned int incY);
 
index b037db2..9a4f53c 100644 (file)
@@ -50,7 +50,7 @@ void CacheElem::swapIn(Options opt) {
   void *buf = device->getBuffer(offset, length, alloc_only);
 
   initial_opt = Options::NONE;
-  mem_data->setAddr((float *)buf);
+  mem_data->setAddr((void *)buf);
   mem_data->setValid(true);
   active = true;
 
index e62d30b..de891d4 100644 (file)
@@ -56,7 +56,7 @@ public:
    */
   explicit CacheElem(std::shared_ptr<SwapDevice> dev, unsigned int mem_id,
                      size_t off, size_t len,
-                     std::shared_ptr<MemoryData<float>> data,
+                     std::shared_ptr<MemoryData> data,
                      CachePolicy pol = CachePolicy::ALWAYS_SYNCED) :
     initial_opt(Options::FIRST_ACCESS),
     device(dev),
@@ -123,7 +123,7 @@ private:
   size_t offset;                      /**< element offset from swap device */
   size_t length;                      /**< element size */
   CachePolicy policy;                 /**< cache policy */
-  std::shared_ptr<MemoryData<float>> mem_data; /**< allocated memory data */
+  std::shared_ptr<MemoryData> mem_data; /**< allocated memory data */
 };
 
 } // namespace nntrainer
index 3a4c188..2d555ce 100644 (file)
@@ -168,7 +168,7 @@ unsigned int CachePool::requestMemory(size_t bytes, unsigned int start_time,
   return id;
 }
 
-std::shared_ptr<MemoryData<float>> CachePool::getMemory(unsigned int id) {
+std::shared_ptr<MemoryData> CachePool::getMemory(unsigned int id) {
   NNTR_THROW_IF(!swap_device->isOperating(), std::invalid_argument)
     << "Allocate memory before allocation";
 
@@ -176,7 +176,7 @@ std::shared_ptr<MemoryData<float>> CachePool::getMemory(unsigned int id) {
   size_t len = getMemorySize().at(id - 1);
   auto exe_order = getMemoryExecOrder().at(id - 1);
   auto policy = getCachePolicy().at(id - 1);
-  auto mem_data = std::make_shared<MemoryData<float>>(
+  auto mem_data = std::make_shared<MemoryData>(
     id, std::bind(&CachePool::validate, this, std::placeholders::_1),
     std::bind(&CachePool::invalidate, this, std::placeholders::_1));
   auto elem =
index 8514ebe..1ef8928 100644 (file)
@@ -87,7 +87,7 @@ public:
    *
    * @details This function will throw if called before allocation.
    */
-  virtual std::shared_ptr<MemoryData<float>> getMemory(unsigned int id);
+  virtual std::shared_ptr<MemoryData> getMemory(unsigned int id);
 
   /**
    * @brief Is the cache pool allocated
index ca32d4e..1494333 100644 (file)
@@ -23,13 +23,13 @@ using MemoryDataValidateCallback = std::function<void(unsigned int)>;
 /**
  * @brief  MemoryData Class
  */
-template <typename T = float> class MemoryData {
+class MemoryData {
 public:
   /**
    * @brief  Constructor of Memory Data
    * @param[in] addr Memory data
    */
-  explicit MemoryData(T *addr) :
+  explicit MemoryData(void *addr) :
     valid(true),
     id(0),
     address(addr),
@@ -63,7 +63,7 @@ public:
   /**
    * @brief  Constructor of MemoryData
    */
-  explicit MemoryData(T *addr, MemoryDataValidateCallback v_cb,
+  explicit MemoryData(void *addr, MemoryDataValidateCallback v_cb,
                       MemoryDataValidateCallback i_cb) = delete;
 
   /**
@@ -74,12 +74,12 @@ public:
   /**
    * @brief  Set address
    */
-  void setAddr(T *addr) { address = addr; }
+  void setAddr(void *addr) { address = addr; }
 
   /**
    * @brief  Get address
    */
-  T *getAddr() const { return address; }
+  void *getAddr() const { return address; }
 
   /**
    * @brief  Validate memory data
@@ -107,7 +107,7 @@ public:
 private:
   bool valid;
   unsigned int id;
-  T *address;
+  void *address;
   MemoryDataValidateCallback validate_cb;
   MemoryDataValidateCallback invalidate_cb;
 };
index 457ff14..5be93ea 100644 (file)
@@ -113,12 +113,12 @@ void MemoryPool::allocate() {
  * @brief Get the allocated memory
  *
  */
-std::shared_ptr<MemoryData<float>> MemoryPool::getMemory(unsigned int idx) {
+std::shared_ptr<MemoryData> MemoryPool::getMemory(unsigned int idx) {
   if (mem_pool == nullptr)
     throw std::invalid_argument("Getting memory before allocation");
 
   char *ptr = static_cast<char *>(mem_pool) + memory_offset.at(idx - 1);
-  auto mem_data = std::make_shared<MemoryData<float>>((float *)ptr);
+  auto mem_data = std::make_shared<MemoryData>((void *)ptr);
 
   return mem_data;
 }
index 1d409c5..440c720 100644 (file)
@@ -103,7 +103,7 @@ public:
    *
    * @details This function will throw if called before allocation.
    */
-  virtual std::shared_ptr<MemoryData<float>> getMemory(unsigned int idx);
+  virtual std::shared_ptr<MemoryData> getMemory(unsigned int idx);
 
   /**
    * @brief Free all the allocated memory
index afb9323..f6feef2 100644 (file)
 #include <iostream>
 #include <iterator>
 #include <numeric>
-#include <random>
 #include <regex>
 #include <sstream>
 #include <stdexcept>
 #include <stdio.h>
 
-#include <blas_interface.h>
 #include <lazy_tensor.h>
-#include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <tensor.h>
 #include <util_func.h>
           }                                                           \
   } while (0);
 
-#define transposeloop_nhwc(cl, ci, cj, ck, sl, si, sj, sk)            \
-  do {                                                                \
-    unsigned int i, j, k, l;                                          \
-    int inidx = 0, outidx = 0;                                        \
-    for (cl = 0; cl < sl; cl++)                                       \
-      for (ci = 0; ci < si; ci++)                                     \
-        for (cj = 0; cj < sj; cj++)                                   \
-          for (ck = 0; ck < sk; ck++) {                               \
-            outidx = si * sj * sk * cl + sj * sk * ci + sk * cj + ck; \
-            inidx = l * SJ * SK * SI + j * SK * SI + k * SI + i;      \
-            outptr[outidx] = inptr[inidx];                            \
-          }                                                           \
-  } while (0);
-
-#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
-  do {                                    \
-    if (tensor.empty()) {                 \
-      tensor = Tensor(__VA_ARGS__);       \
-    }                                     \
-  } while (0);
 namespace nntrainer {
 
 /**
@@ -106,27 +83,24 @@ struct Tensor::BroadcastInfo {
   Tformat fm;
 };
 
-static auto rng = [] {
-  std::mt19937 rng;
-  rng.seed(getSeed());
-  return rng;
-}();
-
 Tensor::Tensor(const TensorDim &d, bool alloc_now, Tensor::Initializer init,
-               std::string name_) :
-  Tensor(name_, d.getFormat()) {
+               std::string name_, nntrainer::DataType d_type) :
+  Tensor(name_) {
   if (d.getDataLen() != 0) {
     dim = d;
     strides = d.computeStrides();
     initializer = init;
-
+    setDataType(d_type);
     if (alloc_now)
       allocate();
   }
 }
 
-Tensor::Tensor(const TensorDim &d, const float *buf) : Tensor(d, true) {
+Tensor::Tensor(const TensorDim &d, const void *buf,
+               nntrainer::DataType d_type) :
+  Tensor(d, true) {
   if (d.getDataLen() != 0) {
+    setDataType(d_type);
     if (buf != nullptr)
       copy(buf);
   }
@@ -179,39 +153,27 @@ void Tensor::allocate() {
     /** as this memory is shared, do NOT initialize */
   } else {
     /// allocate new memory for the tensor data
-    auto mem_data = new MemoryData<float>(new float[dim.getDataLen()]());
-    data = std::shared_ptr<MemoryData<float>>(mem_data, [](auto *mem_data) {
-      delete[] mem_data->getAddr();
-      delete mem_data;
-    });
+    MemoryData *mem_data;
+
+    if (getDataType() == DataType::FP32) {
+      mem_data = new MemoryData((void *)(new float[dim.getDataLen()]()));
+      data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
+        delete[](float *) mem_data->getAddr();
+        delete mem_data;
+      });
+
+    } else if (getDataType() == DataType::FP16) {
+      mem_data = new MemoryData((void *)(new __fp16[dim.getDataLen()]()));
+      data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
+        delete[](__fp16 *) mem_data->getAddr();
+        delete mem_data;
+      });
+    }
     offset = 0;
     initialize();
   }
 }
 
-Tensor Tensor::Map(float *buf, unsigned int bytes, const TensorDim &d,
-                   size_t offset) {
-  if (d.getDataLen() == 0 || buf == nullptr) {
-    throw std::invalid_argument(
-      "[Tensor::Map] empty tensor dim is not allowed");
-  }
-
-  if (d.getDataLen() * sizeof(float) + offset > bytes) {
-    throw std::invalid_argument(
-      "Creating shared tensor of size bigger than tensor memory.");
-  }
-
-  Tensor tmp;
-  tmp.dim = d;
-  tmp.strides = d.computeStrides();
-  /// Tensor does not own the memory
-  tmp.data = std::shared_ptr<MemoryData<float>>(
-    new MemoryData<float>(buf), std::default_delete<MemoryData<float>>());
-  tmp.offset = offset;
-
-  return tmp;
-}
-
 bool Tensor::operator==(const Tensor &rhs) const {
   if (this->dim != rhs.dim)
     return false;
@@ -221,50 +183,49 @@ bool Tensor::operator==(const Tensor &rhs) const {
   if (len != rhs.size())
     return false;
 
-  const float *data = getData();
-  const float *rdata = rhs.getData();
-
   if (contiguous != rhs.contiguous)
     return false;
 
   if (strides != rhs.strides)
     return false;
 
-  for (size_t i = 0; i < len; ++i) {
-    /** not checking sign change is intentional to avoid float calculation
-     * errors around 0 */
-    if ((std::isnan(data[i]) && !std::isnan(rdata[i])) ||
-        (!std::isnan(data[i]) && std::isnan(rdata[i])) ||
-        std::fabs(data[i] - rdata[i]) > epsilon)
-      return false;
+  if (data_type == nntrainer::DataType::FP32) {
+    const float *_data = getData<float>();
+    const float *_rdata = rhs.getData<float>();
+    for (size_t i = 0; i < len; ++i) {
+      /** not checking sign change is intentional to avoid float calculation
+       * errors around 0 */
+      if ((std::isnan(_data[i]) && !std::isnan(_rdata[i])) ||
+          (!std::isnan(_data[i]) && std::isnan(_rdata[i])) ||
+          std::fabs(_data[i] - _rdata[i]) > epsilon)
+        return false;
+    }
+  } else if (data_type == nntrainer::DataType::FP16) {
+    const __fp16 *_data = getData<__fp16>();
+    const __fp16 *_rdata = rhs.getData<__fp16>();
+    for (size_t i = 0; i < len; ++i) {
+      if ((std::isnan(_data[i]) && !std::isnan(_rdata[i])) ||
+          (!std::isnan(_data[i]) && std::isnan(_rdata[i])) ||
+          std::fabs(_data[i] - _rdata[i]) > epsilon)
+        return false;
+    }
   }
 
   return true;
 }
 
-template <typename T> void Tensor::setDist(T dist) {
-  NNTR_THROW_IF(!contiguous, std::invalid_argument)
-    << getName() << " Tensor is not contiguous, cannot set distribution";
-
-  float *data = getData();
-  unsigned int len = size();
-  for (unsigned int i = 0; i < len; ++i) {
-    data[i] = dist(rng);
-  }
-}
-
 void Tensor::setRandNormal(float mean, float std) {
-  setDist<std::normal_distribution<float>>(
+  setDist<float, std::normal_distribution<float>>(
     std::normal_distribution<float>(mean, std));
 }
 
 void Tensor::setRandUniform(float min, float max) {
-  setDist<std::uniform_real_distribution<float>>(
+  setDist<float, std::uniform_real_distribution<float>>(
     std::uniform_real_distribution<float>(min, max));
 }
 
 void Tensor::setRandBernoulli(float probability) {
-  setDist<std::bernoulli_distribution>(
+  setDist<float, std::bernoulli_distribution>(
     std::bernoulli_distribution(probability));
 }
 
@@ -326,57 +287,6 @@ void Tensor::initialize() {
   putData();
 }
 
-Tensor::Tensor(
-  std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
-  Tformat fm) {
-
-  if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
-    throw std::out_of_range(
-      "[Tensor] trying to initialize Tensor from empty vector");
-  }
-
-  // if fm == Tformat::NCHW, then dim[0] == batch , dim[1] == channel, dim[2] ==
-  // height, dim[3] == width. and if fm == Tformat::NHWC, dim[0] == batch,
-  // dim[1] == height, dim[2] == width, dim[3] == channel
-  dim.setTensorDim(0, d.size());
-  if (fm == Tformat::NCHW) {
-    dim.setTensorDim(1, d[0].size());
-    dim.setTensorDim(2, d[0][0].size());
-    dim.setTensorDim(3, d[0][0][0].size());
-  } else {
-    dim.setTensorDim(2, d[0].size());
-    dim.setTensorDim(3, d[0][0].size());
-    dim.setTensorDim(1, d[0][0][0].size());
-  }
-
-  dim.setFormat(fm);
-
-  strides = dim.computeStrides();
-  auto mem_data = new MemoryData<float>(new float[dim.getDataLen()]);
-  data = std::shared_ptr<MemoryData<float>>(
-    mem_data, [](auto *mem_data) { delete[] mem_data->getAddr(); });
-  offset = 0;
-  contiguous = true;
-  initializer = Initializer::NONE;
-
-  // if fm == Tformat::NCHW, then dim[0] == batch , dim[1] == channel, dim[2] ==
-  // height, dim[3] == width. and if fm == Tformat::NHWC, dim[0] == batch,
-  // dim[1] == height, dim[2] == width, dim[3] == channel
-  if (fm == Tformat::NCHW) {
-    for (unsigned int i = 0; i < batch(); ++i)
-      for (unsigned int j = 0; j < channel(); ++j)
-        for (unsigned int k = 0; k < height(); ++k)
-          for (unsigned int l = 0; l < width(); ++l)
-            this->setValue(i, j, k, l, d[i][j][k][l]);
-  } else {
-    for (unsigned int i = 0; i < batch(); ++i)
-      for (unsigned int j = 0; j < height(); ++j)
-        for (unsigned int k = 0; k < width(); ++k)
-          for (unsigned int l = 0; l < channel(); ++l)
-            this->setValue(i, l, j, k, d[i][j][k][l]);
-  }
-}
-
 int Tensor::multiply_i_strided(Tensor const &m, const float beta) {
   try {
     this->multiply_strided(m, *this, beta);
@@ -396,20 +306,12 @@ Tensor Tensor::multiply_strided(Tensor const &m, const float beta) const {
 Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
                                  const float beta) const {
   /** TODO: throw than create new dimenions */
-  CREATE_IF_EMPTY_DIMS(output, dim);
+  CREATE_IF_EMPTY_DIMS(output, dim, nullptr, data_type);
 
   if (size() != m.size() || size() != output.size())
     throw std::invalid_argument(
       "Strided multiplication does not support broadcasting");
-
-  NNTR_THROW_IF(getData() == nullptr, std::invalid_argument)
-    << getName() << " is not allocated";
-  NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument)
-    << m.getName() << " is not allocated";
-  NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument)
-    << output.getName() << " is not allocated";
-
-  if (this->getFormat() == Tformat::NCHW) {
+  if (data_type == DataType::FP32) {
     if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
         beta != 0.0) {
       for (unsigned int b = 0; b < batch(); ++b) {
@@ -417,7 +319,8 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
               output.addValue(b, c, h, w,
-                              getValue(b, c, h, w) * m.getValue(b, c, h, w),
+                              getValue<float>(b, c, h, w) *
+                                m.getValue<float>(b, c, h, w),
                               beta);
             }
           }
@@ -437,15 +340,16 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
         }
       }
     }
-  } else {
+  } else if (data_type == DataType::FP16) {
     if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
         beta != 0.0) {
       for (unsigned int b = 0; b < batch(); ++b) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            for (unsigned int c = 0; c < channel(); ++c) {
+        for (unsigned int c = 0; c < channel(); ++c) {
+          for (unsigned int h = 0; h < height(); ++h) {
+            for (unsigned int w = 0; w < width(); ++w) {
               output.addValue(b, c, h, w,
-                              getValue(b, c, h, w) * m.getValue(b, c, h, w),
+                              getValue<__fp16>(b, c, h, w) *
+                                m.getValue<__fp16>(b, c, h, w),
                               beta);
             }
           }
@@ -454,13 +358,13 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
     } else {
       /** @todo optimize this with combining these loops where stride is 1 */
       for (unsigned int b = 0; b < batch(); ++b) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            float *out_data = output.getAddress(b, 0, h, w);
-            const float *m_data = m.getAddress(b, 0, h, w);
-            const float *in_data = getAddress(b, 0, h, w);
-            std::transform(in_data, in_data + channel(), m_data, out_data,
-                           std::multiplies<float>());
+        for (unsigned int c = 0; c < channel(); ++c) {
+          for (unsigned int h = 0; h < height(); ++h) {
+            __fp16 *out_data = output.getAddress<__fp16>(b, c, h, 0);
+            const __fp16 *m_data = m.getAddress<__fp16>(b, c, h, 0);
+            const __fp16 *in_data = getAddress<__fp16>(b, c, h, 0);
+            std::transform(in_data, in_data + width(), m_data, out_data,
+                           std::multiplies<__fp16>());
           }
         }
       }
@@ -489,20 +393,12 @@ Tensor Tensor::add_strided(Tensor const &m, const float beta) const {
 Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
                             const float beta) const {
   /** TODO: throw than create new dimenions */
-  CREATE_IF_EMPTY_DIMS(output, dim);
+  CREATE_IF_EMPTY_DIMS(output, dim, nullptr, data_type);
 
   if (size() != m.size() || size() != output.size())
     throw std::invalid_argument(
       "Strided addition does not support broadcasting");
-
-  NNTR_THROW_IF(getData() == nullptr, std::invalid_argument)
-    << getName() << " is not allocated";
-  NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument)
-    << m.getName() << " is not allocated";
-  NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument)
-    << output.getName() << " is not allocated";
-
-  if (this->getFormat() == Tformat::NCHW) {
+  if (data_type == DataType::FP32) {
     if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
         beta != 0.0) {
       for (unsigned int b = 0; b < batch(); ++b) {
@@ -510,8 +406,8 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
           for (unsigned int h = 0; h < height(); ++h) {
             for (unsigned int w = 0; w < width(); ++w) {
               output.setValue(b, c, h, w,
-                              getValue(b, c, h, w) +
-                                m.getValue(b, c, h, w) * beta);
+                              getValue<float>(b, c, h, w) +
+                                m.getValue<float>(b, c, h, w) * beta);
             }
           }
         }
@@ -521,25 +417,25 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
       for (unsigned int b = 0; b < batch(); ++b) {
         for (unsigned int c = 0; c < channel(); ++c) {
           for (unsigned int h = 0; h < height(); ++h) {
-            float *out_data = output.getAddress(b, c, h, 0);
-            const float *m_data = m.getAddress(b, c, h, 0);
-            const float *in_data = getAddress(b, c, h, 0);
+            float *out_data = output.getAddress<float>(b, c, h, 0);
+            const float *m_data = m.getAddress<float>(b, c, h, 0);
+            const float *in_data = getAddress<float>(b, c, h, 0);
             std::transform(in_data, in_data + width(), m_data, out_data,
                            std::plus<float>());
           }
         }
       }
     }
-  } else {
+  } else if (data_type == DataType::FP16) {
     if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
         beta != 0.0) {
       for (unsigned int b = 0; b < batch(); ++b) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            for (unsigned int c = 0; c < channel(); ++c) {
+        for (unsigned int c = 0; c < channel(); ++c) {
+          for (unsigned int h = 0; h < height(); ++h) {
+            for (unsigned int w = 0; w < width(); ++w) {
               output.setValue(b, c, h, w,
-                              getValue(b, c, h, w) +
-                                m.getValue(b, c, h, w) * beta);
+                              getValue<__fp16>(b, c, h, w) +
+                                m.getValue<__fp16>(b, c, h, w) * beta);
             }
           }
         }
@@ -547,19 +443,18 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
     } else {
       /** @todo optimize this with combining these loops where stride is 1 */
       for (unsigned int b = 0; b < batch(); ++b) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            float *out_data = output.getAddress(b, 0, h, w);
-            const float *m_data = m.getAddress(b, 0, h, w);
-            const float *in_data = getAddress(b, 0, h, w);
-            std::transform(in_data, in_data + channel(), m_data, out_data,
-                           std::plus<float>());
+        for (unsigned int c = 0; c < channel(); ++c) {
+          for (unsigned int h = 0; h < height(); ++h) {
+            __fp16 *out_data = output.getAddress<__fp16>(b, c, h, 0);
+            const __fp16 *m_data = m.getAddress<__fp16>(b, c, h, 0);
+            const __fp16 *in_data = getAddress<__fp16>(b, c, h, 0);
+            std::transform(in_data, in_data + width(), m_data, out_data,
+                           std::plus<__fp16>());
           }
         }
       }
     }
   }
-
   return output;
 }
 
@@ -569,10 +464,16 @@ int Tensor::multiply_i(float const &value) {
 
   /// @note this is not depending on multiply_i as there is an optimized
   /// version for multiply_i
-  float *data = getData();
-  unsigned int len = size();
+  if (data_type == DataType::FP32) {
+    float *data = getData<float>();
+    unsigned int len = size();
 
-  sscal(len, value, data, 1);
+    sscal(len, value, data, 1);
+  } else if (data_type == DataType::FP16) {
+    __fp16 *data = getData<__fp16>();
+    unsigned int len = size();
+    sscal(len, value, data, 1);
+  }
   return ML_ERROR_NONE;
 }
 
@@ -583,8 +484,13 @@ Tensor Tensor::multiply(float const &value) const {
 
 Tensor &Tensor::multiply(float const &value, Tensor &out) const {
   /// @todo add unittest
-  auto f = std::bind(std::multiplies<float>(), std::placeholders::_1, value);
-  return apply(f, out);
+  if (data_type == DataType::FP32) {
+    auto f = std::bind(std::multiplies<float>(), std::placeholders::_1, value);
+    return apply(f, out);
+  } else if (data_type == DataType::FP16) {
+    auto f = std::bind(std::multiplies<__fp16>(), std::placeholders::_1, value);
+    return apply(f, out);
+  }
 }
 
 int Tensor::multiply_i(Tensor const &m, const float beta) {
@@ -609,33 +515,53 @@ Tensor &Tensor::multiply(Tensor const &m, Tensor &output,
    * @note this does not work correctly with differently strided inputs.
    * Use multiply_strided alternatively
    */
-  auto f = [&](const BroadcastInfo &e, const float *buf, const float *m_buf,
-               float *out_buf) {
-    if (e.strides[3] == 1 && output.strides[3] == 1 && strides[3] == 1 &&
-        beta == 0.0) {
-      std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
-                     std::multiplies<float>());
-    } else {
-      for (unsigned int i = 0; i < e.buffer_size; ++i) {
-        *out_buf = *buf * *m_buf + beta * *out_buf;
-        buf += strides[3];
-        m_buf += e.strides[3];
-        out_buf += output.strides[3];
+  if (data_type == DataType::FP32) {
+    auto f = [&](const BroadcastInfo &e, const float *buf, const float *m_buf,
+                 float *out_buf) {
+      if (e.strides[3] == 1 && output.strides[3] == 1 && strides[3] == 1 &&
+          beta == 0.0) {
+        std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
+                       std::multiplies<float>());
+      } else {
+        for (unsigned int i = 0; i < e.buffer_size; ++i) {
+          *out_buf = *buf * *m_buf + beta * *out_buf;
+          buf += strides[3];
+          m_buf += e.strides[3];
+          out_buf += output.strides[3];
+        }
       }
-    }
-  };
-
-  NNTR_THROW_IF(m.getFormat() != this->getFormat(), std::invalid_argument)
-    << "Tensor Format of " << getName() << ":"
-    << ((bool)(this->getFormat()) ? "NHWC" : "NCHW") << " is not match. ("
-    << ((bool)(m.getFormat()) ? "NHWC" : "NCHW") << ")";
+    };
+
+    NNTR_THROW_IF(!contiguous || !m.contiguous || !output.contiguous,
+                  std::invalid_argument)
+      << getName() << " is not contiguous, cannot multiply";
+
+    apply_broadcast(m, f, output);
+    return output;
+  } else if (data_type == DataType::FP16) {
+    auto f = [&](const BroadcastInfo &e, const __fp16 *buf, const __fp16 *m_buf,
+                 __fp16 *out_buf) {
+      if (e.strides[3] == 1 && output.strides[3] == 1 && strides[3] == 1 &&
+          beta == 0.0) {
+        std::transform(buf, buf + e.buffer_size, m_buf, out_buf,
+                       std::multiplies<__fp16>());
+      } else {
+        for (unsigned int i = 0; i < e.buffer_size; ++i) {
+          *out_buf = *buf * *m_buf + beta * *out_buf;
+          buf += strides[3];
+          m_buf += e.strides[3];
+          out_buf += output.strides[3];
+        }
+      }
+    };
 
-  NNTR_THROW_IF(!contiguous || !m.contiguous || !output.contiguous,
-                std::invalid_argument)
-    << getName() << " is not contiguous, cannot multiply";
+    NNTR_THROW_IF(!contiguous || !m.contiguous || !output.contiguous,
+                  std::invalid_argument)
+      << getName() << " is not contiguous, cannot multiply";
 
-  apply_broadcast(m, f, output);
-  return output;
+    apply_broadcast(m, f, output);
+    return output;
+  }
 }
 
 int Tensor::divide_i(float const &value) {
@@ -1038,12 +964,10 @@ Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
                                   [axis](unsigned cur, const Tensor &t) {
                                     return cur += t.getDim().getTensorDim(axis);
                                   });
-  auto iter_value =
-    [is_format_nchw](std::array<unsigned, 4> &loc,
-                     const std::array<unsigned, 4> &start_loc, Tensor &t,
-                     const std::array<unsigned, 4> &ref_dim_arr) -> float & {
-    auto &value = is_format_nchw ? t.getValue(loc[0], loc[1], loc[2], loc[3])
-                                 : t.getValue(loc[0], loc[3], loc[1], loc[2]);
+  auto iter_value = [](std::array<unsigned, 4> &loc,
+                       std::array<unsigned, 4> &start_loc, Tensor &t,
+                       const TensorDim &ref_dim) -> float & {
+    auto &value = t.getValue<float>(loc[0], loc[1], loc[2], loc[3]);
     for (int i = 3; i >= 0; --i) {
       loc[i]++;
       if (loc[i] - start_loc[i] == ref_dim_arr[i]) {
@@ -1077,19 +1001,7 @@ Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
     }
 
     for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
-      iter_value(loc, start_loc, ret, tensor_dim_arr) = t.getValue(i);
-    }
-
-    if (is_format_nchw) {
-      loc[axis] += t.getDim().getTensorDim(axis);
-    } else {
-      if (axis == 0) {
-        loc[0] += t.getDim().getTensorDim(axis);
-      } else if (axis == 1) {
-        loc[3] += t.getDim().getTensorDim(axis);
-      } else if (axis == 2 || axis == 3) {
-        loc[axis - 1] += t.getDim().getTensorDim(axis);
-      }
+      iter_value(loc, start_loc, ret, t.getDim()) = t.getValue<float>(i);
     }
   }
 
@@ -1142,6 +1054,35 @@ void Tensor::apply_broadcast(
   return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m));
 }
 
+void Tensor::apply_broadcast(
+  Tensor const &m,
+  std::function<void(const BroadcastInfo &e, const __fp16 *, const __fp16 *,
+                     __fp16 *)>
+    v_func,
+  Tensor &output) const {
+  CREATE_IF_EMPTY_DIMS(output, dim, nullptr, data_type);
+
+  NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+    << getName() << " is not allocated";
+  NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+    << m.getName() << " is not allocated";
+  NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+    << output.getName() << " is not allocated";
+
+  /// shortcut to cover when dimension matches
+  /// note that buffer_size, the last stride is only used in v_func but it
+  /// might be changed
+  if (dim == m.dim) {
+    BroadcastInfo e;
+    e.buffer_size = size();
+    e.strides[3] = 1;
+    v_func(e, getData<__fp16>(), m.getData<__fp16>(), output.getData<__fp16>());
+    return;
+  }
+
+  return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m));
+}
+
 void Tensor::apply_broadcast_util(
   Tensor const &m,
   std::function<void(const BroadcastInfo &e, const float *, const float *,
@@ -1174,6 +1115,32 @@ void Tensor::apply_broadcast_util(
   }
 }
 
+void Tensor::apply_broadcast_util(
+  Tensor const &m,
+  std::function<void(const BroadcastInfo &e, const __fp16 *, const __fp16 *,
+                     __fp16 *)>
+    v_func,
+  Tensor &output, const BroadcastInfo &e, int cur_axis, size_t offset,
+  size_t m_offset) const {
+
+  const __fp16 *buf = this->getData<__fp16>();
+  const __fp16 *m_buf = m.getData<__fp16>();
+  __fp16 *out_buf = output.getData<__fp16>();
+
+  if (e.buffer_axis == cur_axis) {
+    v_func(e, buf + offset, m_buf + m_offset, out_buf + offset);
+    return;
+  }
+
+  cur_axis++;
+  for (unsigned int i = 0; i < dim.getTensorDim(cur_axis); ++i) {
+    size_t next_offset = offset + i * strides[cur_axis];
+    size_t next_m_offset = m_offset + i * e.strides[cur_axis];
+    apply_broadcast_util(m, v_func, output, e, cur_axis, next_offset,
+                         next_m_offset);
+  }
+}
+
 /**
  * This is to sum the Tensor data according to the dim.batch().
  * Therefore the result has M(dim.batch(), 1, 1, 1) dimension.
@@ -1206,7 +1173,7 @@ Tensor Tensor::sum(unsigned int axis, float alpha) const {
 }
 Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
                     float beta) const {
-  const float *data = getData();
+  const float *data = getData<float>();
 
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     << getName() << " is not contiguous, cannot sum";
@@ -1232,81 +1199,44 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
           ones.getData(), 1, beta, ret.getData(), 1);
   } break;
   case 1: {
-    CREATE_IF_EMPTY_DIMS(ret, dim[0], 1, dim[2], dim[3], this->getFormat());
-    if (this->getFormat() == Tformat::NHWC) {
-      unsigned int m = ret.dim.getDataLen();
-      unsigned int n = dim[1];
-      Tensor ones(1, 1, 1, n, this->getFormat());
-      ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n, ones.getData(), 1,
-            beta, ret.getData(), 1);
-    } else {
-      unsigned int feat_len = dim[2] * dim[3];
-      unsigned int t_axis = dim[1];
-      Tensor ones(1, 1, 1, t_axis);
-      ones.setValue(alpha);
-      float *rdata = ret.getData();
-      for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
-              &data[k * dim.getFeatureLen()], feat_len, ones.getData(), 1, beta,
-              &rdata[k * feat_len], 1);
-      }
+    CREATE_IF_EMPTY_DIMS(ret, dim.batch(), 1, dim.height(), dim.width());
+    unsigned int feat_len = dim.height() * dim.width();
+    unsigned int channel = dim.channel();
+    Tensor ones(1, 1, 1, channel);
+    ones.setValue(alpha);
+    float *rdata = ret.getData<float>();
+    for (unsigned int k = 0; k < dim.batch(); ++k) {
+      sgemv(CblasRowMajor, CblasTrans, channel, feat_len, 1,
+            &data[k * dim.getFeatureLen()], feat_len, ones.getData(), 1, beta,
+            &rdata[k * feat_len], 1);
     }
   } break;
   case 2: {
-    CREATE_IF_EMPTY_DIMS(ret, dim[0], dim[1], 1, dim[3], this->getFormat());
-
-    if (this->getFormat() == Tformat::NHWC) {
-      unsigned int feat_len = dim[1] * dim[3];
-      unsigned int t_axis = dim[2];
-      Tensor ones(1, 1, 1, t_axis, this->getFormat());
-      ones.setValue(alpha);
-      float *rdata = ret.getData();
-      for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
-              &data[k * dim.getFeatureLen()], feat_len, ones.getData(), 1, beta,
-              &rdata[k * feat_len], 1);
-      }
-    } else {
-      unsigned int t_3 = dim[3];
-      unsigned int t_axis = dim[2];
-      Tensor ones(1, 1, 1, t_axis, this->getFormat());
-      ones.setValue(alpha);
-      float *rdata = ret.getData();
-      for (unsigned int k = 0; k < dim[0]; ++k) {
-        for (unsigned int c = 0; c < dim[1]; ++c) {
-          unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
-          unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim[3];
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData(), 1, beta, &rdata[ridx], 1);
-        }
+    CREATE_IF_EMPTY_DIMS(ret, dim.batch(), dim.channel(), 1, dim.width());
+    unsigned int width = dim.width();
+    unsigned int height = dim.height();
+    Tensor ones(1, 1, 1, height);
+    ones.setValue(alpha);
+    float *rdata = ret.getData<float>();
+    for (unsigned int k = 0; k < dim.batch(); ++k) {
+      for (unsigned int c = 0; c < dim.channel(); ++c) {
+        unsigned int idx =
+          k * dim.getFeatureLen() + c * dim.width() * dim.height();
+        unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim.width();
+        sgemv(CblasRowMajor, CblasTrans, height, width, 1, &data[idx], width,
+              ones.getData(), 1, beta, &rdata[ridx], 1);
       }
     }
   } break;
   case 3: {
-    CREATE_IF_EMPTY_DIMS(ret, dim[0], dim[1], dim[2], 1, this->getFormat());
-    if (this->getFormat() == Tformat::NHWC) {
-      unsigned int t_3 = dim[1];
-      unsigned int t_axis = dim[3];
-      Tensor ones(1, 1, 1, t_axis, this->getFormat());
-      ones.setValue(alpha);
-      float *rdata = ret.getData();
-      for (unsigned int k = 0; k < dim[0]; ++k) {
-        for (unsigned int c = 0; c < dim[2]; ++c) {
-          unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
-          unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim[1];
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData(), 1, beta, &rdata[ridx], 1);
-        }
-      }
-    } else {
-      unsigned int m = ret.dim.getDataLen();
-      unsigned int n = dim[3];
-      Tensor ones(1, 1, 1, n);
-      ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n, ones.getData(), 1,
-            beta, ret.getData(), 1);
-    }
+    CREATE_IF_EMPTY_DIMS(ret, dim.batch(), dim.channel(), dim.height(), 1,
+                         Tformat::NCHW, DataType::FP32);
+    unsigned int m = ret.dim.getDataLen();
+    unsigned int n = dim.width();
+    Tensor ones(1, 1, 1, n);
+    ones.setValue(alpha);
+    sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n, ones.getData<float>(),
+          1, beta, ret.getData<float>(), 1);
   } break;
   default:
     throw std::out_of_range("Error: Dimension cannot exceed 3");
@@ -1737,56 +1667,56 @@ void Tensor::zoneout_mask(Tensor &opposite, float zoneout) {
   }
 }
 
-int Tensor::apply_i(std::function<float(float)> f) {
-  Tensor result = *this;
-  apply(f, result);
-
-  return ML_ERROR_NONE;
-}
-
-Tensor Tensor::apply(std::function<float(float)> f) const {
-  Tensor result;
-  return apply(f, result);
-}
-
-Tensor &Tensor::apply(std::function<float(float)> f, Tensor &output) const {
-  CREATE_IF_EMPTY_DIMS(output, dim);
-
-  if (dim != output.dim) {
-    /// @todo add unittest
-    throw std::invalid_argument(
-      "[Tensor::apply] output dimension does not match");
-  }
-
-  if (contiguous && output.contiguous) {
-    const float *data = getData();
-    float *rdata = output.getData();
-    std::transform(data, data + size(), rdata, f);
-  } else if (strides[3] == 1 && output.strides[3] == 1) {
-    /** @todo optimize this with combining these loops where stride is 1 */
-    for (unsigned int b = 0; b < batch(); ++b) {
-      for (unsigned int c = 0; c < channel(); ++c) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          float *out_data = output.getAddress(b, c, h, 0);
-          const float *in_data = getAddress(b, c, h, 0);
-          std::transform(in_data, in_data + width(), out_data, f);
-        }
-      }
-    }
-  } else {
-    for (unsigned int b = 0; b < batch(); ++b) {
-      for (unsigned int c = 0; c < channel(); ++c) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            output.setValue(b, c, h, w, f(getValue(b, c, h, w)));
-          }
-        }
-      }
-    }
-  }
-
-  return output;
-}
+// int Tensor::apply_i(std::function<float(float)> f) {
+//   Tensor result = *this;
+//   apply(f, result);
+
+//   return ML_ERROR_NONE;
+// }
+
+// Tensor Tensor::apply(std::function<float(float)> f) const {
+//   Tensor result;
+//   return apply(f, result);
+// }
+
+// Tensor &Tensor::apply(std::function<float(float)> f, Tensor &output) const {
+//   CREATE_IF_EMPTY_DIMS(output, dim);
+
+//   if (dim != output.dim) {
+//     /// @todo add unittest
+//     throw std::invalid_argument(
+//       "[Tensor::apply] output dimension does not match");
+//   }
+
+//   if (contiguous && output.contiguous) {
+//     const float *data = getData();
+//     float *rdata = output.getData();
+//     std::transform(data, data + size(), rdata, f);
+//   } else if (strides[3] == 1 && output.strides[3] == 1) {
+//     /** @todo optimize this with combining these loops where stride is 1 */
+//     for (unsigned int b = 0; b < batch(); ++b) {
+//       for (unsigned int c = 0; c < channel(); ++c) {
+//         for (unsigned int h = 0; h < height(); ++h) {
+//           float *out_data = output.getAddress(b, c, h, 0);
+//           const float *in_data = getAddress(b, c, h, 0);
+//           std::transform(in_data, in_data + width(), out_data, f);
+//         }
+//       }
+//     }
+//   } else {
+//     for (unsigned int b = 0; b < batch(); ++b) {
+//       for (unsigned int c = 0; c < channel(); ++c) {
+//         for (unsigned int h = 0; h < height(); ++h) {
+//           for (unsigned int w = 0; w < width(); ++w) {
+//             output.setValue(b, c, h, w, f(getValue(b, c, h, w)));
+//           }
+//         }
+//       }
+//     }
+//   }
+
+//   return output;
+// }
 
 Tensor Tensor::apply(std::function<Tensor(Tensor)> f) const { return f(*this); }
 
@@ -1812,74 +1742,12 @@ void Tensor::print(std::ostream &out) const {
 
   std::ios init(NULL);
   init.copyfmt(out);
-  if (getFormat() == Tformat::NCHW) {
-    for (unsigned int k = 0; k < batch(); k++) {
-      for (unsigned int l = 0; l < channel(); l++) {
-        for (unsigned int i = 0; i < height(); i++) {
-          for (unsigned int j = 0; j < width(); j++) {
-            out << std::setw(10) << std::setprecision(10)
-                << this->getValue(k, l, i, j) << " ";
-          }
-          out << std::endl;
-        }
-        out << std::endl;
-      }
-      out << "-------" << std::endl;
-    }
-  } else {
-    for (unsigned int k = 0; k < batch(); k++) {
-      for (unsigned int i = 0; i < height(); i++) {
-        for (unsigned int j = 0; j < width(); j++) {
-          for (unsigned int l = 0; l < channel(); l++) {
-            out << std::setw(10) << std::setprecision(10)
-                << this->getValue(k, l, i, j) << " ";
-          }
-          out << std::endl;
-        }
-        out << std::endl;
-      }
-      out << "-------" << std::endl;
-    }
-  }
-
-  out.copyfmt(init);
-}
-
-void Tensor::print_(std::ostream &out, uint opt) const {
-  printInstance(out, this);
-
-  unsigned int len = size();
-
-  std::ios init(NULL);
-  init.copyfmt(out);
-  if (opt == 0) {
-    if (getFormat() == Tformat::NCHW) {
-      out << "{";
-      for (unsigned int k = 0; k < batch(); k++) {
-        out << "{";
-        for (unsigned int i = 0; i < channel(); i++) {
-          out << "{";
-          for (unsigned int j = 0; j < height(); j++) {
-            out << "{";
-            for (unsigned int l = 0; l < width(); l++) {
-              if (l < width() - 1)
-                out << std::setw(10) << std::setprecision(10)
-                    << this->getValue(k, l, i, j) << ", ";
-              else
-                out << std::setw(10) << std::setprecision(10)
-                    << this->getValue(k, l, i, j);
-            }
-            if (j < height() - 1)
-              out << "},";
-            else
-              out << "}";
-            out << std::endl;
-          }
-          if (i < channel() - 1)
-            out << "},";
-          else
-            out << "}";
-          out << std::endl;
+  for (unsigned int k = 0; k < dim.batch(); k++) {
+    for (unsigned int l = 0; l < dim.channel(); l++) {
+      for (unsigned int i = 0; i < dim.height(); i++) {
+        for (unsigned int j = 0; j < dim.width(); j++) {
+          out << std::setw(10) << std::setprecision(10)
+              << this->getValue<float>(k, l, i, j) << " ";
         }
         if (k < batch() - 1)
           out << "},";
@@ -1937,7 +1805,7 @@ std::ostream &operator<<(std::ostream &out, Tensor const &m) {
   return out;
 }
 
-void Tensor::copy(const float *buf) {
+void Tensor::copy(const void *buf) {
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     << getName() << "Tensor is not contiguous, cannot copy.";
 
@@ -1945,7 +1813,7 @@ void Tensor::copy(const float *buf) {
     return;
   }
 
-  scopy(size(), buf, 1, getData(), 1);
+  scopy(size(), buf, 1, getData(), 1, getDataType());
 }
 
 void Tensor::copy_with_stride(const Tensor &from) {
@@ -1955,7 +1823,7 @@ void Tensor::copy_with_stride(const Tensor &from) {
       for (unsigned int c = 0; c < channel(); ++c) {
         for (unsigned int h = 0; h < height(); ++h) {
           for (unsigned int w = 0; w < width(); ++w) {
-            setValue(b, c, h, w, from.getValue(b, c, h, w));
+            setValue(b, c, h, w, from.getValue<float>(b, c, h, w));
           }
         }
       }
@@ -1966,7 +1834,7 @@ void Tensor::copy_with_stride(const Tensor &from) {
       for (unsigned int c = 0; c < t.channel(); ++c) {
         for (unsigned int h = 0; h < t.height(); ++h) {
           for (unsigned int w = 0; w < t.width(); ++w) {
-            t.setValue(b, c, h, w, from.getValue(b, c, h, w));
+            t.setValue(b, c, h, w, from.getValue<float>(b, c, h, w));
           }
         }
       }
@@ -2149,15 +2017,22 @@ void Tensor::setValue(float val) {
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     << getName() << " is not contiguous, cannot set value.";
 
-  float *data = getData();
+  float *data = getData<float>();
   std::fill(data, data + size(), val);
 }
 
 void Tensor::setZero() {
-  if (contiguous)
-    sscal(size(), 0, getData(), 1);
-  else
-    apply_i([](float val) -> float { return 0; });
+  if (data_type == nntrainer::DataType::FP32) {
+    if (contiguous)
+      sscal(size(), 0, getData<float>(), 1);
+    else
+      apply_i([](float val) -> float { return 0; });
+  } else if (data_type == nntrainer::DataType::FP16) {
+    if (contiguous)
+      sscal(size(), 0, getData<__fp16>(), 1);
+    else
+      apply_i([](__fp16 val) -> __fp16 { return 0; });
+  }
 }
 
 std::vector<unsigned int> Tensor::argmax() const {
@@ -2338,4 +2213,21 @@ Tensor::BroadcastInfo Tensor::computeBroadcastInfo(const Tensor &m) const {
   return e;
 }
 
+Tensor Tensor::rotate_180(Tensor in) {
+  Tensor output(in.getDim());
+  output.setZero();
+  for (unsigned int i = 0; i < in.batch(); ++i) {
+    for (unsigned int j = 0; j < in.channel(); ++j) {
+      for (unsigned int k = 0; k < in.height(); ++k) {
+        for (unsigned int l = 0; l < in.width(); ++l) {
+          output.setValue(i, j, k, l,
+                          in.getValue<float>(i, j, (in.height() - k - 1),
+                                             (in.width() - l - 1)));
+        }
+      }
+    }
+  }
+  return output;
+}
+
 } /* namespace nntrainer */
index 9206e20..9307e94 100644 (file)
 #include <array>
 #include <functional>
 #include <memory>
+#include <random>
 #include <stdexcept>
 #include <vector>
 
+#include <blas_interface.h>
+#include <iostream>
 #include <memory_data.h>
+#include <nntrainer_error.h>
 #include <tensor_dim.h>
+#include <util_func.h>
 
 #ifdef DEBUG
 #define EXCEPT_WHEN_DEBUG
 #endif
 
 #define MAKE_SHARED_TENSOR(...) std::make_shared<nntrainer::Tensor>(__VA_ARGS__)
+#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
+  do {                                    \
+    if (tensor.empty())                   \
+      tensor = Tensor(__VA_ARGS__);       \
+  } while (0);
 
 namespace nntrainer {
 
@@ -73,18 +83,35 @@ public:
     NONE            /** No initialization */
   };
 
+  void setSizeOf(nntrainer::DataType d_type) {
+    switch (d_type) {
+    case nntrainer::DataType::FP16:
+      sizeof_d = sizeof(__fp16);
+      return;
+    case nntrainer::DataType::FP32:
+      sizeof_d = sizeof(float);
+      return;
+    default:
+      return;
+    }
+  }
+
   /**
    * @brief     Basic Constructor of Tensor
    */
-  Tensor(std::string name_ = "", Tformat fm = Tformat::NCHW) :
+  Tensor(std::string name_ = "", Tformat fm = Tformat::NCHW,
+         nntrainer::DataType d_type = nntrainer::DataType::FP32) :
     dim(TensorDim(fm)),
     strides(dim.computeStrides()),
     contiguous(true),
     initializer(Initializer::NONE),
     name(name_),
+    data_type(d_type),
     data(nullptr),
     offset(0),
-    src_tensor() {}
+    src_tensor() {
+    setSizeOf(d_type);
+  }
 
   /**
    * @brief     Constructor of Tensor with dimension, possibly lazily
@@ -94,7 +121,8 @@ public:
    * @param name Name of the tensor
    */
   Tensor(const TensorDim &d, bool alloc_now,
-         Initializer init = Initializer::NONE, std::string name = "");
+         Initializer init = Initializer::NONE, std::string name = "",
+         nntrainer::DataType d_type = nntrainer::DataType::FP32);
 
   /**
    * @brief     Constructor of Tensor with dimension/buf
@@ -102,7 +130,8 @@ public:
    * @param buf buffer
    * @note Memory for this tensor is instantaneously allocated
    */
-  Tensor(const TensorDim &d, const float *buf = nullptr);
+  Tensor(const TensorDim &d, const void *buf = nullptr,
+         nntrainer::DataType d_type = nntrainer::DataType::FP32);
 
   /**
    * @brief     Constructor of Tensor
@@ -111,9 +140,9 @@ public:
    * @param[in] d2 Height
    * @param[in] d3 Width
    */
-  Tensor(size_t d0, size_t d1, size_t d2, size_t d3,
-         Tformat fm = Tformat::NCHW) :
-    Tensor(TensorDim(d0, d1, d2, d3, fm)){};
+  Tensor(size_t d0, size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
+         nntrainer::DataType d_type = nntrainer::DataType::FP32) :
+    Tensor(TensorDim(d0, d1, d2, d3, fm), nullptr, d_type){};
 
   /**
    * @brief     Constructor of Tensor
@@ -121,32 +150,60 @@ public:
    * @param[in] d2 Height
    * @param[in] d3 Width
    */
-  Tensor(size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW) :
-    Tensor(1, d1, d2, d3, fm){};
+  Tensor(size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
+         nntrainer::DataType d_type = nntrainer::DataType::FP32) :
+    Tensor(1, d1, d2, d3, fm, d_type){};
 
   /**
    * @brief     Constructor of Tensor with batch size one and d1 size one
    * @param[in] d2 Height (NCHW) or Width (NHWC)
    * @param[in] d3 Width (NCHW) or Channel (NHWC)
    */
-  Tensor(size_t d2, size_t d3, Tformat fm = Tformat::NCHW) :
-    Tensor(1, (fm == Tformat::NCHW) ? 1 : d3, (fm == Tformat::NCHW) ? d2 : 1,
-           (fm == Tformat::NCHW) ? d3 : d2, fm){};
+  Tensor(size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
+         nntrainer::DataType d_type = nntrainer::DataType::FP32) :
+    Tensor(1, 1, d2, d3, fm, d_type){};
 
   /**
    * @brief     Constructor of Tensor with just Width or Channel
    * @param[in] d3 Width (NCHW) or Channel (NHWC)
    */
-  explicit Tensor(size_t d3, Tformat fm = Tformat::NCHW) :
-    Tensor(1, (fm == Tformat::NCHW) ? 1 : d3, 1, (fm == Tformat::NCHW) ? d3 : 1,
-           fm){};
+  explicit Tensor(size_t d3, Tformat fm = Tformat::NCHW,
+                  nntrainer::DataType d_type = nntrainer::DataType::FP32) :
+    Tensor(1, 1, 1, d3, fm, d_type){};
 
   /**
    * @brief     Constructor of Tensor
    * @param[in] d data for the Tensor. It needs to set format properly.
    */
-  Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
-         Tformat fm = Tformat::NCHW);
+  Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d) {
+
+    if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
+      throw std::out_of_range(
+        "[Tensor] trying to initialize Tensor from empty vector");
+    }
+
+    dim.batch(d.size());
+    dim.channel(d[0].size());
+    dim.height(d[0][0].size());
+    dim.width(d[0][0][0].size());
+    strides = dim.computeStrides();
+
+    auto mem_data = new MemoryData((void *)(new float[dim.getDataLen()]()));
+    data = std::shared_ptr<MemoryData>(
+      mem_data, [](auto *mem_data) { delete[](float *) mem_data->getAddr(); });
+    offset = 0;
+    contiguous = true;
+    initializer = Initializer::NONE;
+
+    setDataType(DataType::FP32);
+
+    for (unsigned int i = 0; i < dim.batch(); ++i)
+      for (unsigned int j = 0; j < dim.channel(); ++j)
+        for (unsigned int k = 0; k < dim.height(); ++k)
+          for (unsigned int l = 0; l < dim.width(); ++l) {
+            this->setValue(i, j, k, l, d[i][j][k][l]);
+          }
+  };
 
   /**
    * @brief     Constructor of Tensor
@@ -165,6 +222,51 @@ public:
   Tensor(std::vector<std::vector<float>> const &d, Tformat fm = Tformat::NCHW) :
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, fm){};
 
+  Tensor(std::vector<std::vector<std::vector<std::vector<__fp16>>>> const &d) {
+
+    if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
+      throw std::out_of_range(
+        "[Tensor] trying to initialize Tensor from empty vector");
+    }
+
+    dim.batch(d.size());
+    dim.channel(d[0].size());
+    dim.height(d[0][0].size());
+    dim.width(d[0][0][0].size());
+    strides = dim.computeStrides();
+
+    auto mem_data = new MemoryData((void *)(new __fp16[dim.getDataLen()]()));
+    data = std::shared_ptr<MemoryData>(
+      mem_data, [](auto *mem_data) { delete[](__fp16 *) mem_data->getAddr(); });
+    offset = 0;
+    contiguous = true;
+    initializer = Initializer::NONE;
+
+    setDataType(DataType::FP16);
+
+    for (unsigned int i = 0; i < dim.batch(); ++i)
+      for (unsigned int j = 0; j < dim.channel(); ++j)
+        for (unsigned int k = 0; k < dim.height(); ++k)
+          for (unsigned int l = 0; l < dim.width(); ++l)
+            this->setValue(i, j, k, l, d[i][j][k][l]);
+  };
+
+  /**
+   * @brief     Constructor of Tensor
+   * @note      This constructor copies vector again. needs refactoring
+   * @param[in] d data for the Tensor
+   */
+  Tensor(std::vector<std::vector<std::vector<__fp16>>> const &d) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
+
+  /**
+   * @brief     Constructor of Tensor
+   * @note      This constructor copies vector again. needs refactoring
+   * @param[in] d data for the Tensor with batch size one
+   */
+  Tensor(std::vector<std::vector<__fp16>> const &d) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
+
   /**
    *  @brief  Copy constructor of Tensor.
    *  @param[in] Tensor &
@@ -200,8 +302,29 @@ public:
    * @return Tensor object
    * @throws std::invalid_argument if buf is null
    */
-  static Tensor Map(float *buf, unsigned int bytes, const TensorDim &d,
-                    size_t offset = 0);
+  template <typename T = float>
+  static Tensor Map(T *buf, unsigned int bytes, const TensorDim &d,
+                    size_t offset = 0) {
+    if (d.getDataLen() == 0 || buf == nullptr) {
+      throw std::invalid_argument(
+        "[Tensor::Map] empty tensor dim is not allowed");
+    }
+
+    if (d.getDataLen() * sizeof(T) + offset > bytes) {
+      throw std::invalid_argument(
+        "Creating shared tensor of size bigger than tensor memory.");
+    }
+
+    Tensor tmp;
+    tmp.dim = d;
+    tmp.strides = d.computeStrides();
+    /// Tensor does not own the memory
+    tmp.data = std::shared_ptr<MemoryData>(new MemoryData((void *)buf),
+                                           std::default_delete<MemoryData>());
+    tmp.offset = offset;
+
+    return tmp;
+  };
 
   friend void swap(Tensor &lhs, Tensor &rhs) noexcept {
     std::swap(lhs.dim, rhs.dim);
@@ -250,29 +373,34 @@ public:
    * @param[in] h height location
    * @param[in] w width location
    */
-  const float &getValue(unsigned int batch, unsigned int c, unsigned int h,
-                        unsigned int w) const noexcept {
-    return getValue(getIndex(batch, c, h, w));
+  template <typename T = float>
+  const T &getValue(unsigned int batch, unsigned int c, unsigned int h,
+                    unsigned int w) const noexcept {
+    return getValue<T>(getIndex(batch, c, h, w));
   }
 
-  float &getValue(unsigned int batch, unsigned int c, unsigned int h,
-                  unsigned int w) noexcept {
-    return getValue(getIndex(batch, c, h, w));
+  template <typename T = float>
+  T &getValue(unsigned int batch, unsigned int c, unsigned int h,
+              unsigned int w) noexcept {
+    return getValue<T>(getIndex(batch, c, h, w));
   }
 
   /**
    * @brief     return value at specific location
    * @param[in] idx location
    */
-  const float &getValue(unsigned int idx) const noexcept {
-    return getData()[idx];
+  template <typename T = float>
+  const T &getValue(unsigned int idx) const noexcept {
+    return getData<T>()[idx];
   }
 
   /**
    * @brief     return value at specific location
    * @param[in] idx location
    */
-  float &getValue(unsigned int idx) noexcept { return getData()[idx]; }
+  template <typename T = float> T &getValue(unsigned int idx) noexcept {
+    return getData<T>()[idx];
+  }
 
   /**
    * @brief Get the Value thinking that it is padded
@@ -292,9 +420,11 @@ public:
    * @param pw padding width
    * @return float value
    */
-  float getValuePaddedVirtual(unsigned int b, unsigned int c, unsigned int h,
-                              unsigned int w, unsigned int ph, unsigned int pw,
-                              float pad_value = 0) const EXCEPT_WHEN_DEBUG {
+  template <typename T = float>
+  const T getValuePaddedVirtual(unsigned int b, unsigned int c, unsigned int h,
+                                unsigned int w, unsigned int ph,
+                                unsigned int pw,
+                                T pad_value = 0) const EXCEPT_WHEN_DEBUG {
 #if DEBUG
     unsigned int padded_h = 2 * ph + h;
     unsigned int padded_w = 2 * pw + w;
@@ -305,7 +435,7 @@ public:
 #endif
 
     if (ph <= h && h < ph + height() && pw <= w && w < pw + width()) {
-      return getValue(b, c, h - ph, w - pw);
+      return getValue<T>(b, c, h - ph, w - pw);
     }
 
     return pad_value;
@@ -604,6 +734,27 @@ public:
   Tensor &pow(float exponent, Tensor &out) const;
 
   /**
+   * @brief  gaussian error function
+   * @return int ML_ERROR_NONE if successful
+   */
+  int erf_i();
+
+  /**
+   * @brief    gaussian error function
+   * @retval Calculated Tensor
+   */
+  Tensor erf() const;
+
+  /**
+   * @brief    gaussian error function
+   * @param[out] out out to store the result
+   * @retval Calculated Tensor
+   */
+  Tensor &erf(Tensor &out) const;
+
+  unsigned int sizeofData() { return sizeof_d; }
+
+  /**
    * @brief     Dot Product of Tensor ( equal MxM )
    * @details   This applies dot of the last dimension of this and second-last
    * dimension of passed tensor m.
@@ -875,20 +1026,67 @@ public:
    */
   void standardization_i();
 
+  template <typename T = float> T *getAddress(unsigned int i) {
+    size_t index = getIndex(batch(), channel(), height(), width());
+    if (i > index) {
+      return nullptr;
+    }
+    return &getData<T>()[i];
+  }
+
+  /**
+   * @brief     i data index
+   * @retval    address of ith data
+   */
+  template <typename T = float> const T *getAddress(unsigned int i) const {
+    size_t index = getIndex(batch(), channel(), height(), width());
+    if (i > index) {
+      return nullptr;
+    }
+
+    return &getData<T>()[i];
+  }
+
+  /**
+   * @brief    get address of n-d data
+   */
+  template <typename T = float>
+  T *getAddress(unsigned int b, unsigned int c, unsigned int h,
+                unsigned int w) {
+    return getAddress<T>(getIndex(b, c, h, w));
+  }
+
+  /**
+   * @brief    get address of n-d data
+   */
+  template <typename T = float>
+  const T *getAddress(unsigned int b, unsigned int c, unsigned int h,
+                      unsigned int w) const {
+    return getAddress<T>(getIndex(b, c, h, w));
+  }
+
   /**
    * @brief Apply instantly to the element
    *
    * @param f function to apply
    * @return int ML_ERROR_NONE if successful
    */
-  int apply_i(std::function<float(float)> f);
+  int apply_i(std::function<float(float)> f) {
+    Tensor result = *this;
+    apply(f, result);
+
+    return ML_ERROR_NONE;
+  };
 
   /**
    * @brief     Apply function element by element
    * @param[in] *function function pointer applied
    * @retval    Tensor
    */
-  Tensor apply(std::function<float(float)> f) const;
+  Tensor apply(std::function<float(float)> f) const {
+    Tensor result;
+    return apply(f, result);
+  };
 
   /**
    * @brief     Apply function element by element
@@ -896,7 +1094,76 @@ public:
    * @param[out] output output tensor
    * @retval    Tensor
    */
-  Tensor &apply(std::function<float(float)> f, Tensor &output) const;
+  Tensor &apply(std::function<float(float)> f, Tensor &output) const {
+    CREATE_IF_EMPTY_DIMS(output, dim, nullptr, data_type);
+
+    if (dim != output.dim) {
+      /// @todo add unittest
+      throw std::invalid_argument(
+        "[Tensor::apply] output dimension does not match");
+    }
+
+    if (data_type == nntrainer::DataType::FP32) {
+      if (contiguous && output.contiguous) {
+        const float *data = (getData<float>());
+        float *rdata = (output.getData<float>());
+
+        std::transform(data, data + size(), rdata, f);
+      } else if (strides[3] == 1 && output.strides[3] == 1) {
+        /** @todo optimize this with combining these loops where stride is 1 */
+        for (unsigned int b = 0; b < batch(); ++b) {
+          for (unsigned int c = 0; c < channel(); ++c) {
+            for (unsigned int h = 0; h < height(); ++h) {
+              float *out_data = output.getAddress<float>(b, c, h, 0);
+              const float *in_data = getAddress<float>(b, c, h, 0);
+              std::transform(in_data, in_data + width(), out_data, f);
+            }
+          }
+        }
+      } else {
+        for (unsigned int b = 0; b < batch(); ++b) {
+          for (unsigned int c = 0; c < channel(); ++c) {
+            for (unsigned int h = 0; h < height(); ++h) {
+              for (unsigned int w = 0; w < width(); ++w) {
+                output.setValue(b, c, h, w, f(getValue<float>(b, c, h, w)));
+              }
+            }
+          }
+        }
+      }
+    } else if (data_type == nntrainer::DataType::FP16) {
+      if (contiguous && output.contiguous) {
+        const __fp16 *data = (getData<__fp16>());
+        __fp16 *rdata = (output.getData<__fp16>());
+
+        std::transform(data, data + size(), rdata, f);
+      } else if (strides[3] == 1 && output.strides[3] == 1) {
+        /** @todo optimize this with combining these loops where stride is 1 */
+        for (unsigned int b = 0; b < batch(); ++b) {
+          for (unsigned int c = 0; c < channel(); ++c) {
+            for (unsigned int h = 0; h < height(); ++h) {
+              __fp16 *out_data = (__fp16 *)output.getAddress(b, c, h, 0);
+              const __fp16 *in_data = (__fp16 *)getAddress(b, c, h, 0);
+              std::transform(in_data, in_data + width(), out_data, f);
+            }
+          }
+        }
+      } else {
+        for (unsigned int b = 0; b < batch(); ++b) {
+          for (unsigned int c = 0; c < channel(); ++c) {
+            for (unsigned int h = 0; h < height(); ++h) {
+              for (unsigned int w = 0; w < width(); ++w) {
+                output.setValue(b, c, h, w,
+                                f((float)((__fp16)getValue(b, c, h, w))));
+              }
+            }
+          }
+        }
+      }
+    }
+
+    return output;
+  };
 
   /**
    * @brief     Apply function to Tensor
@@ -946,7 +1213,7 @@ public:
    * @brief     Get size of the data in bytes
    * @retval    size_t Size in bytes
    */
-  size_t bytes() const { return size() * sizeof(float); }
+  size_t bytes() const { return size() * sizeof_d; }
 
   /**
    * @brief     Set the element value
@@ -958,7 +1225,11 @@ public:
    */
   void setValue(unsigned int batch, unsigned int c, unsigned int h,
                 unsigned int w, float value) noexcept {
-    getData()[getIndex(batch, c, h, w)] = value;
+    if (data_type == nntrainer::DataType::FP32) {
+      getData<float>()[getIndex(batch, c, h, w)] = value;
+    } else if (data_type == nntrainer::DataType::FP16) {
+      getData<__fp16>()[getIndex(batch, c, h, w)] = value;
+    }
   }
 
   /**
@@ -973,7 +1244,11 @@ public:
   void addValue(unsigned int batch, unsigned int c, unsigned int h,
                 unsigned int w, float value, float beta) noexcept {
     auto const &idx = getIndex(batch, c, h, w);
-    getData()[idx] = value + getData()[idx] * beta;
+    if (data_type == nntrainer::DataType::FP32) {
+      *(float *)(getData(idx)) = value + *(float *)(getData(idx)) * beta;
+    } else if (data_type == nntrainer::DataType::FP16) {
+      *(__fp16 *)(getData(idx)) = value + *(__fp16 *)(getData(idx)) * beta;
+    }
   }
 
   /**
@@ -1001,6 +1276,23 @@ public:
   void setZero();
 
   /**
+   * @brief Set the Dist object
+   *
+   * @tparam T distrubution engine
+   * @param dist distribution engine
+   */
+  template <typename T, typename Engine> void setDist(Engine dist) {
+    NNTR_THROW_IF(!contiguous, std::invalid_argument)
+      << getName() << " Tensor is not contiguous, cannot set distribution";
+
+    T *data_ = getData<T>();
+    unsigned int len = size();
+    for (unsigned int i = 0; i < len; ++i) {
+      data_[i] = (T)dist(rng);
+    }
+  };
+
+  /**
    * @brief     Set the tensor with random normal distribution
    * @param[in] mean mean of the distribution
    * @param[in] std standard deviation of the distribution
@@ -1238,7 +1530,7 @@ public:
       return nullptr;
 
     data->validate();
-    return (T *)((float *)data->getAddr() + offset);
+    return (T *)((T *)(data->getAddr()) + offset);
   }
 
   /**
@@ -1250,7 +1542,26 @@ public:
       return nullptr;
 
     data->validate();
-    return (T *)((float *)data->getAddr() + offset);
+    return (T *)((T *)data->getAddr() + offset);
+  }
+
+  /**
+   * @brief     return Data pointer of Tensor
+   * @retval    template T pointer (float pointer as default)
+   */
+  template <typename T = float> T *getData(size_t idx) const {
+    if (!data)
+      return nullptr;
+
+    size_t index = idx * sizeof_d;
+
+    data->validate();
+    return (T *)((T *)data->getAddr() + offset + index);
+  }
+
+  void setDataType(nntrainer::DataType d_type) {
+    data_type = d_type;
+    setSizeOf(data_type);
   }
 
   /**
@@ -1269,10 +1580,7 @@ public:
    * @brief     return Data pointer of Tensor
    * @retval    template T pointer (float pointer as default)
    */
-  template <typename T = float>
-  const std::shared_ptr<MemoryData<T>> getMemoryData() const {
-    return data;
-  }
+  const std::shared_ptr<MemoryData> getMemoryData() const { return data; }
 
   /**
    * @brief     return offset
@@ -1283,44 +1591,6 @@ public:
    * @brief     i data index
    * @retval    address of ith data
    */
-  template <typename T = float> T *getAddress(unsigned int i) {
-    if (i > getIndex(batch(), channel(), height(), width())) {
-      return nullptr;
-    }
-
-    return &getData<T>()[i];
-  }
-
-  /**
-   * @brief     i data index
-   * @retval    address of ith data
-   */
-  template <typename T = float> const T *getAddress(unsigned int i) const {
-    if (i > getIndex(batch(), channel(), height(), width())) {
-      return nullptr;
-    }
-
-    return &getData<T>()[i];
-  }
-
-  /**
-   * @brief    get address of n-d data
-   */
-  template <typename T = float>
-  T *getAddress(unsigned int b, unsigned int c, unsigned int h,
-                unsigned int w) {
-    return getAddress<T>(getIndex(b, c, h, w));
-  }
-
-  /**
-   * @brief    get address of n-d data
-   */
-  template <typename T = float>
-  const T *getAddress(unsigned int b, unsigned int c, unsigned int h,
-                      unsigned int w) const {
-    return getAddress<T>(getIndex(b, c, h, w));
-  }
-
   /**
    * @brief     set Tensor Dim
    * @param[in] d TensorDim
@@ -1395,8 +1665,8 @@ public:
    * @param buf the memory buffer
    * @param init intialize the buffer
    */
-  void setData(const std::shared_ptr<MemoryData<float>> buf,
-               unsigned int off = 0, bool init = false) {
+  void setData(const std::shared_ptr<MemoryData> buf, unsigned int off = 0,
+               bool init = false) {
     if (buf) {
       data = buf;
       offset = off;
@@ -1422,6 +1692,13 @@ public:
    */
   TensorDim::Format getFormat() const { return dim.getFormat(); }
 
+  /**
+   * @brief Get data type for the tensor
+   *
+   * @return data type of the tensor
+   */
+  nntrainer::DataType getDataType() const { return data_type; }
+
   static constexpr float epsilon = 1e-5;
 
 private:
@@ -1431,8 +1708,10 @@ private:
   bool contiguous;
   Tensor::Initializer initializer;
   std::string name; /**< name of the tensor */
-  std::shared_ptr<MemoryData<float>> data;
+  nntrainer::DataType data_type;
+  std::shared_ptr<MemoryData> data;
   unsigned int offset;
+  unsigned int sizeof_d;
 
   /**<
    * When using shared_data with tensor, this stores the ptr of the source
@@ -1464,6 +1743,14 @@ private:
                        int cur_axis = -1, size_t offset = 0,
                        size_t m_offset = 0) const;
 
+  void apply_broadcast_util(
+    Tensor const &m,
+    std::function<void(const BroadcastInfo &e, const __fp16 *, const __fp16 *,
+                       __fp16 *)>
+      v_func,
+    Tensor &output, const BroadcastInfo &e, int cur_axis = -1,
+    size_t offset = 0, size_t m_offset = 0) const;
+
   /**
    * @brief Applies the given operator to the tensor with the passed argument
    *
@@ -1478,6 +1765,13 @@ private:
                          v_func,
                        Tensor &output) const;
 
+  void
+  apply_broadcast(Tensor const &m,
+                  std::function<void(const BroadcastInfo &e, const __fp16 *,
+                                     const __fp16 *, __fp16 *)>
+                    v_func,
+                  Tensor &output) const;
+
   /**
    * @brief compute Loop info for broadcasting and vectorization
    *
@@ -1487,20 +1781,12 @@ private:
   BroadcastInfo computeBroadcastInfo(const Tensor &m) const;
 
   /**
-   * @brief Set the Dist object
-   *
-   * @tparam T distrubution engine
-   * @param dist distribution engine
-   */
-  template <typename T> void setDist(T dist);
-
-  /**
    * @brief copy a buffer to @a this, the caller has to ensure that @a this is
    * initialized otherwise undefined behavior
    *
    * @param buf buffer to copy from
    */
-  void copy(const float *buf);
+  void copy(const void *buf);
 
   /**
    * @brief Update destination tensor to share memory with source tensor
@@ -1535,6 +1821,14 @@ private:
    * @param axis2 second axis to merge
    */
   void mergeAxis(unsigned int axis1, unsigned int axis2);
+
+  /**
+   * @brief     rotate 180 dgree
+   * @param[in] in input Tensor
+   * @retVal Tensor rotated tensor (180 degree)
+   */
+  Tensor rotate_180(Tensor in);
+
 }; // namespace nntrainer
 
 /**
index 0c5b864..d42764c 100644 (file)
 
 namespace nntrainer {
 
-static auto rng = [] {
-  std::mt19937 rng;
-  rng.seed(getSeed());
-  return rng;
-}();
 static std::uniform_real_distribution<float> dist(-0.5, 0.5);
 
-unsigned int getSeed() { return 0; }
-
 float sqrtFloat(float x) { return sqrt(x); };
 
 double sqrtDouble(double x) { return sqrt(x); };
@@ -50,23 +43,6 @@ float logFloat(float x) { return log(x + 1.0e-20); }
 
 float exp_util(float x) { return exp(x); }
 
-Tensor rotate_180(Tensor in) {
-  Tensor output(in.getDim());
-  output.setZero();
-  for (unsigned int i = 0; i < in.batch(); ++i) {
-    for (unsigned int j = 0; j < in.channel(); ++j) {
-      for (unsigned int k = 0; k < in.height(); ++k) {
-        for (unsigned int l = 0; l < in.width(); ++l) {
-          output.setValue(
-            i, j, k, l,
-            in.getValue(i, j, (in.height() - k - 1), (in.width() - l - 1)));
-        }
-      }
-    }
-  }
-  return output;
-}
-
 bool isFileExist(std::string file_name) {
   std::ifstream infile(file_name);
   return infile.good();
index f6fde9a..68ba151 100644 (file)
 #include <sstream>
 
 #include <nntrainer_error.h>
-#include <tensor.h>
+#include <random>
+
+// /**
+//  * @brief     get the seed
+//  * @return    seed
+//  */
+// unsigned int getSeed() { return 0; }
 
 namespace nntrainer {
 
@@ -64,11 +70,12 @@ inline void throw_status(int status) {
   }
 }
 
-/**
- * @brief     get the seed
- * @return    seed
- */
-unsigned int getSeed();
+static auto rng = [] {
+  std::mt19937 rng;
+  // rng.seed(getSeed());
+  rng.seed(0);
+  return rng;
+}();
 
 /**
  * @brief     sqrt function for float type
@@ -97,13 +104,6 @@ float logFloat(float x);
 float exp_util(float x);
 
 /**
- * @brief     rotate 180 dgree
- * @param[in] in input Tensor
- * @retVal Tensor rotated tensor (180 degree)
- */
-Tensor rotate_180(Tensor in);
-
-/**
  * @brief     Check Existance of File
  * @param[in] file path of the file to be checked
  * @returns   true if file exists, else false
index 5db6bce..a9c5cea 100755 (executable)
@@ -16,13 +16,13 @@ pushd $TARGET
 if [ ! -d builddir ]; then
     #default value of openblas num threads is 1 for android
     #enable-tflite-interpreter=false is just temporally until ci system is stabel
-  meson builddir -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false
+  meson builddir -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false
 else
   echo "warning: $TARGET/builddir has already been taken, this script tries to reconfigure and try building"
   pushd builddir
     #default value of openblas num threads is 1 for android
     #enable-tflite-interpreter=false is just temporally until ci system is stabel  
-    meson configure -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false
+    meson configure -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false
     meson --wipe
   popd
 fi