Restructure inner data inside Tensor
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 1 Jul 2020 06:14:56 +0000 (15:14 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Jul 2020 04:07:53 +0000 (13:07 +0900)
**Changes proposed in this PR:**
- Change Tensor structure to enable sharing between Tensor
- Refactor Tensor ctors
- Add copy/move ctors & assignment operators

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/include/lazy_tensor.h
nntrainer/include/tensor.h
nntrainer/src/tensor.cpp

index dc4bafe..f0a810d 100644 (file)
@@ -40,7 +40,7 @@ public:
    * @brief Constructor of Lazy Tensor, Tensor is copied to gaurantee
    * immutability
    */
-  LazyTensor(const Tensor &from) { target.copy(from); };
+  LazyTensor(const Tensor &from) { target = Tensor(from); };
 
   /**
    * @brief     Wrapper method of add_i. see tensor.h for more detail
index 59cb763..aef0f0d 100644 (file)
@@ -31,13 +31,13 @@ extern "C" {
 }
 #endif
 
+#include <array>
 #include <cmath>
 #include <fstream>
 #include <iostream>
 #include <memory>
 #include <regex>
 #include <tensor_dim.h>
-#include <vector>
 
 namespace nntrainer {
 
@@ -48,23 +48,43 @@ class LazyTensor;
  */
 class Tensor {
 public:
-  /**
-   * @brief     Constructor of Tensor
-   */
-  Tensor() : _is_contiguous(true), dim(){};
+  Tensor(const TensorDim &d, float *buf = nullptr) :
+    dim(d),
+    strides{{1, 2, 3}},
+    is_contiguous(true),
+    data(new float[d.getDataLen()], std::default_delete<float[]>())
+    {
+    // todo: initialize appropriate strides
+    if (buf == nullptr) {
+      setZero();
+    } else {
+      float *data = getData();
+      unsigned int len = length();
+
+#ifdef USE_BLAS
+      cblas_scopy(len, buf, 1, data, 1);
+#else
+      for (unsigned int i = 0; i < len; ++i) {
+        data[i] = buf[i];
+      }
+#endif
+    }
+  }
 
   /**
-   * @brief     Constructor of Tensor with batch size one
-   * @param[in] dim TensorDim
+   * @brief     Basic Constructor of Tensor
    */
-  Tensor(const TensorDim dim);
+  Tensor() : Tensor(TensorDim()){};
 
   /**
-   * @brief     Constructor of Tensor with batch size one
+   * @brief     Constructor of Tensor
+   * @param[in] batch Batch of Tensor
+   * @param[in] channel Channel of Tensor
    * @param[in] heihgt Height of Tensor
    * @param[in] width Width of Tensor
    */
-  Tensor(int height, int width);
+  Tensor(int batch, int channel, int height, int width) :
+    Tensor(TensorDim(batch, channel, height, width)){};
 
   /**
    * @brief     Constructor of Tensor
@@ -72,46 +92,79 @@ public:
    * @param[in] heihgt Height of Tensor
    * @param[in] width Width of Tensor
    */
-  Tensor(int channel, int height, int width);
+  Tensor(int channel, int height, int width) :
+    Tensor(1, channel, height, width){};
 
   /**
-   * @brief     Constructor of Tensor
-   * @param[in] batch Batch of Tensor
-   * @param[in] channel Channel of Tensor
+   * @brief     Constructor of Tensor with batch size one
    * @param[in] heihgt Height of Tensor
    * @param[in] width Width of Tensor
    */
-  Tensor(int batch, int channel, int height, int width);
+  Tensor(int height, int width) : Tensor(1, 1, height, width){};
 
   /**
-   * @brief   Constructor of Tensor
-   * @param[in] d data for the Tensor with batch size one
+   * @brief     Constructor of Tensor
+   * @param[in] d data for the Tensor
    */
-  Tensor(std::vector<std::vector<float>> const &d);
+  Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d);
 
   /**
    * @brief     Constructor of Tensor
+   * @note      This constructor copies vector again. needs refactoring
    * @param[in] d data for the Tensor
    */
-  Tensor(std::vector<std::vector<std::vector<float>>> const &d);
+  Tensor(std::vector<std::vector<std::vector<float>>> const &d) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
 
   /**
    * @brief     Constructor of Tensor
-   * @param[in] d data for the Tensor
+   * @note      This constructor copies vector again. needs refactoring
+   * @param[in] d data for the Tensor with batch size one
    */
-  Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d);
+  Tensor(std::vector<std::vector<float>> const &d) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
+
+  /**
+   *  @brief  Copy constructor of Tensor.
+   *  @note This can be safely reverted to default
+   *        after checking using _data as a pointer is safe for functions using
+   * Tensor.
+   *  @param[in] Tensor &
+   */
+  Tensor(const Tensor &rhs) : Tensor(rhs.dim, rhs.data.get()){};
+
+  /**
+   *  @brief  Move constructor of Tensor.
+   *  @param[in] Tensor &&
+   */
+  Tensor(Tensor &&rhs) noexcept = default;
+
+  /**
+   * @brief  Copy assignment operator.
+   * @param[in] rhs Tensor to be copied.
+   */
+  // todo: refactor operator= to consider allocated size for the data
+  Tensor &operator=(const Tensor &rhs);
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Tensor to be moved.
+   */
+  Tensor &operator=(Tensor &&rhs) noexcept;
+
+  void swap(Tensor &lhs, Tensor &rhs) noexcept;
 
   /**
    * @brief     Comparison operator overload
    * @param[in] rhs Tensor to be compared with
    */
-  bool operator== (const Tensor &rhs) const;
+  bool operator==(const Tensor &rhs) const;
 
   /**
    * @brief     Comparison operator overload
    * @param[in] rhs Tensor to be compared with
    */
-  bool operator!= (const Tensor &rhs) const { return !(*this == rhs); }
+  bool operator!=(const Tensor &rhs) const { return !(*this == rhs); }
 
   /**
    * @brief     return value at specific location
@@ -365,12 +418,16 @@ public:
   int getBatch() const { return dim.batch(); };
 
   /**
+   * @brief     Get length of current _data
+   * @retval    unsigned int length of the current _data
+   */
+  unsigned int length() const { return dim.getDataLen(); }
+
+  /**
    * @brief     Get size of the data
    * @retval    size_t Size in bytes
    */
-  size_t getSize() const {
-    return dim.getDataLen() * sizeof(decltype(data)::value_type);
-  }
+  size_t getSize() const { return length() * sizeof(float); }
 
   /**
    * @brief     Set the element value
@@ -408,7 +465,7 @@ public:
    * @param[in] from Tensor to be Copyed
    * @retval    Matix
    */
-  Tensor &copy(Tensor const &from);
+  Tensor &copy(Tensor &from);
 
   /**
    * @brief     Save the Tensor into file
@@ -462,9 +519,9 @@ public:
    * @brief     return Data pointer of Tensor
    * @retval    float pointer
    */
-  float *getData() { return data.data(); }
+  float *getData() { return data.get(); }
 
-  const float *getData() const { return data.data(); }
+  const float *getData() const { return data.get(); }
 
   /**
    * @brief     i data index
@@ -485,20 +542,20 @@ public:
    *            on this tensor
    * @retval    bool is contigous
    */
-  const bool isContiguous() const noexcept { return _is_contiguous; }
+  const bool isContiguous() const noexcept { return is_contiguous; }
 
   /**
    * @brief     return current stride of tensor.
    * @retval    int[MAXDIM] strides
    */
-  const int *strides() const noexcept { return _strides; }
+  const std::array<int, MAXDIM> getStrides() const noexcept { return strides; }
 
 private:
-  /**< handle the data as a std::vector type */
-  std::vector<float> data;
-  int _strides[MAXDIM];
-  bool _is_contiguous;
+  /**< handle the data as a std::shared_ptr<float> type */
   TensorDim dim;
+  std::array<int, MAXDIM> strides;
+  bool is_contiguous;
+  std::shared_ptr<float> data;
 
   static constexpr float min_limits = std::numeric_limits<float>::min();
   static constexpr float max_limits = std::numeric_limits<float>::max();
index d13b606..c71ee33 100644 (file)
@@ -61,50 +61,47 @@ static auto rng = [] {
   return rng;
 }();
 
-Tensor::Tensor(const TensorDim d) {
-  dim = d;
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-  setZero();
+Tensor &Tensor::operator=(const Tensor &rhs) {
+  using std::swap;
+  
+  Tensor tmp(rhs);
+  swap(*this, tmp);
+  return *this;
 }
 
-Tensor::Tensor(int height, int width) {
-  dim.height(height);
-  dim.width(width);
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-  setZero();
-}
+Tensor &Tensor::operator=(Tensor &&rhs) noexcept {
+  using std::swap;
+
+  std::swap(dim, rhs.dim);
+  std::swap(data, rhs.data);
+  std::swap(strides, rhs.strides);
+  std::swap(is_contiguous, rhs.is_contiguous);
 
-Tensor::Tensor(int channel, int height, int width) {
-  dim.height(height);
-  dim.width(width);
-  dim.batch(channel);
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-  setZero();
+  return *this;
 }
 
-Tensor::Tensor(int batch, int channel, int height, int width) {
-  dim.height(height);
-  dim.width(width);
-  dim.batch(batch);
-  dim.channel(channel);
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-  setZero();
+void Tensor::swap(Tensor &lhs, Tensor &rhs) noexcept {
+  std::swap(lhs.dim, rhs.dim);
+  std::swap(lhs.data, rhs.data);
+  std::swap(lhs.strides, rhs.strides);
+  std::swap(lhs.is_contiguous, rhs.is_contiguous);
 }
 
-bool Tensor::operator== (const Tensor &rhs) const {
+bool Tensor::operator==(const Tensor &rhs) const {
   if (this->dim != rhs.dim)
     return false;
 
-  if (data.size() != rhs.data.size())
+  size_t len = length();
+
+  if (len != rhs.length())
     return false;
 
-  for (size_t i = 0; i < data.size(); ++i) {
-    if (std::isnan(data[i]) || std::isnan(rhs.data[i]) ||
-        std::fabs(data[i] - rhs.data[i]) > epsilon)
+  const float *data = getData();
+  const float *rdata = rhs.getData();
+
+  for (size_t i = 0; i < len; ++i) {
+    if (std::isnan(data[i]) || std::isnan(rdata[i]) ||
+        std::fabs(data[i] - rdata[i]) > epsilon)
       return false;
   }
 
@@ -113,22 +110,24 @@ bool Tensor::operator== (const Tensor &rhs) const {
 
 float Tensor::getValue(unsigned int batch, unsigned int c, unsigned int h,
                        unsigned int w) const {
-  return this->data[batch * dim.channel() * dim.height() * dim.width() +
-                    c * dim.height() * dim.width() + h * dim.width() + w];
+  return getData()[batch * dim.getFeatureLen() +
+                   c * dim.height() * dim.width() + h * dim.width() + w];
 }
 
 void Tensor::setValue(unsigned int batch, unsigned int c, unsigned int h,
                       unsigned int w, float value) {
-  if(!_is_contiguous) {
+  if (!is_contiguous) {
     throw std::runtime_error("cannot set value of non-contiguous tensor");
   }
 
-  this->data[batch * dim.channel() * dim.height() * dim.width() +
-             c * dim.height() * dim.width() + h * dim.width() + w] = value;
+  getData()[batch * dim.getFeatureLen() +
+            c * dim.height() * dim.width() + h * dim.width() + w] = value;
 }
 
 template <typename T> void Tensor::setDist(T dist) {
-  for (unsigned int i = 0; i < dim.getDataLen(); ++i) {
+  float *data = getData();
+  unsigned int len = length();
+  for (unsigned int i = 0; i < len; ++i) {
     data[i] = dist(rng);
   }
 }
@@ -143,39 +142,21 @@ void Tensor::setRandUniform(float min, float max) {
     std::uniform_real_distribution<float>(min, max));
 }
 
-Tensor::Tensor(std::vector<std::vector<float>> const &d) {
-  dim.height(d.size());
-  dim.width(d[0].size());
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-
-  for (unsigned int j = 0; j < dim.height(); ++j)
-    for (unsigned int k = 0; k < dim.width(); ++k)
-      this->setValue(0, 0, j, k, d[j][k]);
-}
-
-Tensor::Tensor(std::vector<std::vector<std::vector<float>>> const &d) {
-  dim.channel(d.size());
-  dim.height(d[0].size());
-  dim.width(d[0][0].size());
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
-
-  for (unsigned int j = 0; j < dim.channel(); ++j)
-    for (unsigned int k = 0; k < dim.height(); ++k)
-      for (unsigned int l = 0; l < dim.width(); ++l)
-        this->setValue(0, j, k, l, d[j][k][l]);
-}
-
 Tensor::Tensor(
   std::vector<std::vector<std::vector<std::vector<float>>>> const &d) {
 
+  if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
+    throw std::out_of_range(
+      "[Tensor] trying to initialize Tensor from empty vector");
+  }
+
   dim.batch(d.size());
   dim.channel(d[0].size());
   dim.height(d[0][0].size());
   dim.width(d[0][0][0].size());
-  this->data = std::vector<float>(dim.getDataLen());
-  _is_contiguous = true;
+  data = std::shared_ptr<float>(new float[dim.getDataLen()],
+                                 std::default_delete<float[]>());
+  is_contiguous = true;
 
   for (unsigned int i = 0; i < dim.batch(); ++i)
     for (unsigned int j = 0; j < dim.channel(); ++j)
@@ -185,19 +166,22 @@ Tensor::Tensor(
 }
 
 int Tensor::multiply_i(float const &value) {
+
+  float *data = getData();
+  unsigned int len = length();
+
 #ifdef USE_BLAS
-  cblas_sscal(dim.getDataLen(), value, this->data.data(), 1);
+  cblas_sscal(len, value, data, 1);
 #else
-  for (unsigned int k = 0; k < dim.getDataLen(); ++k) {
-    this->data[k] *= value;
+  for (unsigned int k = 0; k < len; ++k) {
+    data[k] *= value;
   }
 #endif
   return ML_ERROR_NONE;
 }
 
 Tensor Tensor::multiply(float const &value) {
-  Tensor result(dim);
-  result.copy(*this);
+  Tensor result(*this);
   result.multiply_i(value);
 
   return result;
@@ -212,27 +196,25 @@ int Tensor::divide_i(float const &value) {
 }
 
 Tensor Tensor::divide(float const &value) {
-  Tensor result(dim);
   if (value == 0.0) {
     throw std::runtime_error("Error: Divide by zero");
   }
-
-  result.copy(*this);
+  Tensor result(*this);
   result.divide_i(value);
 
   return result;
 }
 
 int Tensor::add_i(float const &value) {
+  float *data = getData();
+  unsigned int len = length();
 #ifdef USE_BLAS
   Tensor tmp(dim);
-  for (unsigned int i = 0; i < tmp.dim.getDataLen(); ++i)
-    tmp.data[i] = 1.0;
-  cblas_saxpy(dim.getDataLen(), value, tmp.data.data(), 1, this->data.data(),
-              1);
+  tmp.setValue(1.0);
+  cblas_saxpy(len, value, tmp.getData(), 1, data, 1);
 #else
-  for (unsigned int k = 0; k < dim.getDataLen(); ++k) {
-    this->data[k] = this->data[k] + value;
+  for (unsigned int k = 0; k < len; ++k) {
+    data[k] += value;
   }
 #endif
 
@@ -240,9 +222,7 @@ int Tensor::add_i(float const &value) {
 }
 
 Tensor Tensor::add(float const &value) {
-  Tensor result(dim);
-
-  result.copy(*this);
+  Tensor result(*this);
   result.add_i(value);
 
   return result;
@@ -259,16 +239,18 @@ int Tensor::add_i(Tensor const &m, float const alpha) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
+  float *data = getData();
+  const float *mdata = m.getData();
+  unsigned int len = length();
+
 #ifdef USE_BLAS
   unsigned int size = dim.width() * dim.height() * dim.channel();
   if (m.dim.batch() == 1) {
     for (unsigned int k = 0; k < dim.batch(); ++k) {
-      cblas_saxpy(size, alpha, m.data.data(), 1, &(this->data.data()[k * size]),
-                  1);
+      cblas_saxpy(size, alpha, mdata, 1, &(data[k * size]), 1);
     }
   } else {
-    cblas_saxpy(dim.getDataLen(), alpha, m.data.data(), 1, this->data.data(),
-                1);
+    cblas_saxpy(len, alpha, mdata, 1, data, 1);
   }
 #else
   unsigned int i, j, k;
@@ -276,12 +258,12 @@ int Tensor::add_i(Tensor const &m, float const alpha) {
     for (k = 0; k < dim.batch(); ++k) {
       for (i = 0; i < m.dim.getFeatureLen(); ++i) {
         j = k * m.dim.getFeatureLen();
-        this->data[j + i] += alpha * m.data[i];
+        data[j + i] += alpha * mdata[i];
       }
     }
   } else {
-    for (k = 0; k < dim.getDataLen(); ++k) {
-      this->data[k] += alpha * m.data[k];
+    for (k = 0; k < len; ++k) {
+      data[k] += alpha * mdata[k];
     }
   }
 #endif
@@ -294,8 +276,7 @@ Tensor Tensor::add(Tensor const &m, float const alpha) const {
     throw std::runtime_error("Error: Dimension must be equal each other");
   }
 
-  Tensor result(dim);
-  result.copy(*this);
+  Tensor result(*this);
   result.add_i(m, alpha);
 
   return result;
@@ -311,6 +292,10 @@ int Tensor::subtract_i(Tensor const &m) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
+  float *data = getData();
+  const float *mdata = m.getData();
+  unsigned int len = length();
+
 #ifdef USE_BLAS
   unsigned int size =
     this->dim.channel() * this->dim.width() * this->dim.height();
@@ -318,27 +303,24 @@ int Tensor::subtract_i(Tensor const &m) {
 
   if (m.dim.batch() == 1) {
     for (unsigned int k = 0; k < dim.batch(); ++k) {
-      cblas_saxpy(size, alpha, m.data.data(), 1, &(this->data.data()[k * size]),
-                  1);
+      cblas_saxpy(size, alpha, mdata, 1, &(data[k * size]), 1);
     }
   } else {
-    cblas_saxpy(dim.getDataLen(), alpha, m.data.data(), 1, this->data.data(),
-                1);
+    cblas_saxpy(len, alpha, mdata, 1, data, 1);
   }
 #else
-  unsigned int i, j, k, len, dlen;
+  unsigned int i, j, k, dlen;
   if (m.dim.batch() == 1) {
     len = m.dim.getFeatureLen();
     for (k = 0; k < dim.batch(); ++k) {
       for (i = 0; i < len; ++i) {
         j = k * len;
-        this->data[j + i] = this->data[j + i] - m.data[i];
+        data[j + i] -= mdata[i];
       }
     }
   } else {
-    dlen = m.dim.getDataLen();
-    for (k = 0; k < dlen; ++k) {
-      this->data[k] = data[k] - m.data[k];
+    for (k = 0; k < len; ++k) {
+      data[k] -= mdata[k];
     }
   }
 #endif
@@ -352,8 +334,7 @@ Tensor Tensor::subtract(Tensor const &m) const {
     throw std::runtime_error("Error: Dimension must be equal each other");
   }
 
-  Tensor result(dim);
-  result.copy(*this);
+  Tensor result(*this);
   result.subtract_i(m);
 
   return result;
@@ -362,9 +343,8 @@ Tensor Tensor::subtract(Tensor const &m) const {
 int Tensor::subtract_i(float const &value) { return this->add_i(-value); }
 
 Tensor Tensor::subtract(float const &value) {
-  Tensor result(dim);
+  Tensor result(*this);
 
-  result.copy(*this);
   if (result.subtract_i(value) != ML_ERROR_NONE) {
     throw std::runtime_error("Error: there was an error on subtraction");
   }
@@ -378,30 +358,34 @@ int Tensor::multiply_i(Tensor const &m) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
-  int end = dim.getDataLen() / 4;
-  int e = dim.getFeatureLen() / 4;
-  int i;
+  float *data = getData();
+  const float *mdata = m.getData();
+
+  unsigned int len = length();
+  unsigned int end = len / 4;
+  unsigned int e = dim.getFeatureLen() / 4;
+  unsigned int i;
   if (m.dim.batch() == 1) {
     for (unsigned int k = 0; k < dim.batch(); ++k) {
       int b = k * dim.getFeatureLen();
       for (i = 0; i < e * 4; i += 4) {
-        this->data[b + i + 0] *= m.data[i + 0];
-        this->data[b + i + 1] *= m.data[i + 1];
-        this->data[b + i + 2] *= m.data[i + 2];
-        this->data[b + i + 3] *= m.data[i + 3];
+        data[b + i + 0] *= mdata[i + 0];
+        data[b + i + 1] *= mdata[i + 1];
+        data[b + i + 2] *= mdata[i + 2];
+        data[b + i + 3] *= mdata[i + 3];
       }
       for (unsigned int j = i; j < dim.getFeatureLen(); j++)
-        this->data[b + j] = this->data[b + j] * m.data[j];
+        data[b + j] *= mdata[j];
     }
   } else {
     for (i = 0; i < end * 4; i += 4) {
-      this->data[i + 0] *= m.data[i + 0];
-      this->data[i + 1] *= m.data[i + 1];
-      this->data[i + 2] *= m.data[i + 2];
-      this->data[i + 3] *= m.data[i + 3];
+      data[i + 0] *= mdata[i + 0];
+      data[i + 1] *= mdata[i + 1];
+      data[i + 2] *= mdata[i + 2];
+      data[i + 3] *= mdata[i + 3];
     }
-    for (unsigned int j = i; j < dim.getDataLen(); ++j)
-      this->data[j] = this->data[j] * m.data[j];
+    for (unsigned int j = i; j < len; ++j)
+      data[j] *= mdata[j];
   }
 
   return ML_ERROR_NONE;
@@ -413,8 +397,7 @@ Tensor Tensor::multiply(Tensor const &m) const {
     throw std::runtime_error("Error: Dimension must be equal each other");
   }
 
-  Tensor result(dim);
-  result.copy(*this);
+  Tensor result(*this);
   result.multiply_i(m);
 
   return result;
@@ -426,7 +409,11 @@ int Tensor::divide_i(Tensor const &m) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
-  unsigned int end = dim.getDataLen() / 4;
+  float *data = getData();
+  const float *mdata = m.getData();
+
+  unsigned int len = length();
+  unsigned int end = len / 4;
   unsigned int e = dim.getFeatureLen() / 4;
   unsigned int i, j, k;
 
@@ -435,23 +422,23 @@ int Tensor::divide_i(Tensor const &m) {
     for (k = 0; k < dim.batch(); ++k) {
       unsigned int b = k * dim.getFeatureLen();
       for (i = 0; i < e * 4; i += 4) {
-        this->data[b + i + 0] /= m.data[i + 0];
-        this->data[b + i + 1] /= m.data[i + 1];
-        this->data[b + i + 2] /= m.data[i + 2];
-        this->data[b + i + 3] /= m.data[i + 3];
+        data[b + i + 0] /= mdata[i + 0];
+        data[b + i + 1] /= mdata[i + 1];
+        data[b + i + 2] /= mdata[i + 2];
+        data[b + i + 3] /= mdata[i + 3];
       }
       for (unsigned int j = i; j < dim.getFeatureLen(); ++j)
-        this->data[b + j] /= m.data[j];
+        data[b + j] /= mdata[j];
     }
   } else {
     for (i = 0; i < end * 4; i += 4) {
-      this->data[i + 0] /= m.data[i + 0];
-      this->data[i + 1] /= m.data[i + 1];
-      this->data[i + 2] /= m.data[i + 2];
-      this->data[i + 3] /= m.data[i + 3];
+      data[i + 0] /= mdata[i + 0];
+      data[i + 1] /= mdata[i + 1];
+      data[i + 2] /= mdata[i + 2];
+      data[i + 3] /= mdata[i + 3];
     }
-    for (j = i; j < dim.getDataLen(); ++j)
-      this->data[j] /= m.data[j];
+    for (j = i; j < len; ++j)
+      data[j] /= mdata[j];
   }
 
   return ML_ERROR_NONE;
@@ -463,9 +450,7 @@ Tensor Tensor::divide(Tensor const &m) const {
     throw std::runtime_error("Error: Dimension must be equal each other");
   }
 
-  Tensor result(dim.batch(), dim.channel(), dim.height(), dim.width());
-
-  result.copy(*this);
+  Tensor result(*this);
   result.divide_i(m);
 
   return result;
@@ -478,17 +463,21 @@ Tensor Tensor::divide(Tensor const &m) const {
 Tensor Tensor::sum_by_batch() const {
   unsigned int k;
   Tensor ret(dim.batch(), 1, 1, 1);
+
+  const float *data = getData();
+  float *rdata = ret.getData();
+
 #ifdef USE_BLAS
   for (k = 0; k < dim.batch(); ++k)
-    ret.data[k] = cblas_sasum(dim.getFeatureLen(),
-                              &(data.data()[k * dim.getFeatureLen()]), 1);
+    rdata[k] =
+      cblas_sasum(dim.getFeatureLen(), &(data[k * dim.getFeatureLen()]), 1);
 #else
   unsigned int i;
   for (k = 0; k < dim.batch(); ++k) {
     unsigned int id = k * dim.getFeatureLen();
-    ret.data[k] = 0.0;
+    rdata[k] = 0.0;
     for (i = 0; i < dim.getFeatureLen(); ++i) {
-      ret.data[k] += data[id + i];
+      rdata[k] += data[id + i];
     }
   }
 #endif
@@ -501,10 +490,13 @@ Tensor Tensor::sum_by_batch() const {
  */
 Tensor Tensor::sum(int axis) const {
   Tensor ret;
+
+  const float *data = getData();
+
   switch (axis) {
   case 0: {
-
     ret = Tensor(1, dim.channel(), dim.height(), dim.width());
+    float *rdata = ret.getData();
     for (unsigned int l = 0; l < dim.channel(); ++l) {
       unsigned int L = l * dim.width() * dim.height();
       for (unsigned int i = 0; i < dim.height(); ++i) {
@@ -512,7 +504,7 @@ Tensor Tensor::sum(int axis) const {
         for (unsigned int j = 0; j < dim.width(); ++j) {
           for (unsigned int k = 0; k < dim.batch(); ++k) {
             unsigned int K = k * dim.getFeatureLen();
-            ret.data[L + I + j] += data[K + L + I + j];
+            rdata[L + I + j] += data[K + L + I + j];
           }
         }
       }
@@ -520,6 +512,7 @@ Tensor Tensor::sum(int axis) const {
   } break;
   case 1: {
     ret = Tensor(dim.batch(), 1, dim.height(), dim.width());
+    float *rdata = ret.getData();
     for (unsigned int l = 0; l < dim.batch(); ++l) {
       unsigned int L = dim.width() * dim.height() * l;
       unsigned int LL = l * dim.getFeatureLen();
@@ -528,7 +521,7 @@ Tensor Tensor::sum(int axis) const {
         for (unsigned int i = 0; i < dim.width(); ++i) {
           for (unsigned int k = 0; k < dim.channel(); ++k) {
             unsigned int K = k * dim.width() * dim.height();
-            ret.data[(L + J + i)] += data[LL + K + J + i];
+            rdata[L + J + i] += data[LL + K + J + i];
           }
         }
       }
@@ -536,6 +529,7 @@ Tensor Tensor::sum(int axis) const {
   } break;
   case 2: {
     ret = Tensor(dim.batch(), dim.channel(), 1, dim.width());
+    float *rdata = ret.getData();
     for (unsigned int k = 0; k < dim.batch(); ++k) {
       unsigned int K = k * dim.channel() * dim.width();
       unsigned int KK = k * dim.getFeatureLen();
@@ -545,7 +539,7 @@ Tensor Tensor::sum(int axis) const {
         for (unsigned int j = 0; j < dim.width(); ++j) {
           for (unsigned int i = 0; i < dim.height(); ++i) {
             unsigned int I = i * dim.width();
-            ret.data[K + L + j] += data[KK + LL + j + I];
+            rdata[K + L + j] += data[KK + LL + j + I];
           }
         }
       }
@@ -553,6 +547,7 @@ Tensor Tensor::sum(int axis) const {
   } break;
   case 3: {
     ret = Tensor(dim.batch(), dim.channel(), dim.height(), 1);
+    float *rdata = ret.getData();
     for (unsigned int k = 0; k < dim.batch(); ++k) {
       unsigned int K = k * dim.channel() * dim.height();
       unsigned int KK = k * dim.getFeatureLen();
@@ -562,7 +557,7 @@ Tensor Tensor::sum(int axis) const {
         for (unsigned int i = 0; i < dim.height(); ++i) {
           unsigned int II = i * dim.width();
           for (unsigned int j = 0; j < dim.width(); ++j) {
-            ret.data[K + L + i] += data[KK + LL + II + j];
+            rdata[K + L + i] += data[KK + LL + II + j];
           }
         }
       }
@@ -592,6 +587,10 @@ Tensor Tensor::dot(Tensor const &m) const {
   int mwidth = m.dim.width();
   Tensor result(dim.batch(), 1, dim.height(), mwidth);
 
+  const float *data = getData();
+  const float *mdata = m.getData();
+  float *rdata = result.getData();
+
 #ifdef USE_BLAS
   float alpha_dgemm = 1.0;
   float beta_dgemm = 1.0;
@@ -600,9 +599,9 @@ Tensor Tensor::dot(Tensor const &m) const {
       unsigned int i = k * dim.width() * dim.height();
       unsigned int ii = k * dim.height() * m.dim.width();
       cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, dim.height(),
-                  m.dim.width(), dim.width(), alpha_dgemm, &(data.data()[i]),
-                  dim.width(), m.data.data(), m.dim.width(), beta_dgemm,
-                  &(result.data.data()[ii]), m.dim.width());
+                  m.dim.width(), dim.width(), alpha_dgemm, &(data[i]),
+                  dim.width(), mdata, m.dim.width(), beta_dgemm, &(rdata[ii]),
+                  m.dim.width());
     }
   } else {
     for (unsigned int k = 0; k < dim.batch(); k++) {
@@ -611,9 +610,9 @@ Tensor Tensor::dot(Tensor const &m) const {
       unsigned int ii = k * dim.height() * m.dim.width();
 
       cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, dim.height(),
-                  m.dim.width(), dim.width(), alpha_dgemm, &(data.data()[i]),
-                  dim.width(), &(m.data.data()[j]), m.dim.width(), beta_dgemm,
-                  &(result.data.data()[ii]), m.dim.width());
+                  m.dim.width(), dim.width(), alpha_dgemm, &(data[i]),
+                  dim.width(), &(mdata[j]), m.dim.width(), beta_dgemm,
+                  &(rdata[ii]), m.dim.width());
     }
   }
 #elif USE_CUBLAS
@@ -634,8 +633,8 @@ Tensor Tensor::dot(Tensor const &m) const {
 
       cudaMalloc((void **)&d_A, size_A);
       cudaMalloc((void **)&d_B, size_B);
-      cudaMemcpy(d_A, &data.data()[i], size_A, cudaMemcpyHostToDevice);
-      cudaMemcpy(d_B, m.data.data(), size_B, cudaMemcpyHostToDevice);
+      cudaMemcpy(d_A, &data[i], size_A, cudaMemcpyHostToDevice);
+      cudaMemcpy(d_B, mdata, size_B, cudaMemcpyHostToDevice);
       cudaMalloc((void **)&d_C, size_C);
 
       {
@@ -649,8 +648,7 @@ Tensor Tensor::dot(Tensor const &m) const {
                      dim.height(), dim.width(), &alpha, d_B, m.dim.width(), d_A,
                      dim.width(), &beta, d_C, m.dim.width()));
 
-        (cudaMemcpy(&result.data.data()[ii], d_C, size_C,
-                    cudaMemcpyDeviceToHost));
+        (cudaMemcpy(&rdata[ii], d_C, size_C, cudaMemcpyDeviceToHost));
         (cublasDestroy(handle));
       }
     }
@@ -662,8 +660,8 @@ Tensor Tensor::dot(Tensor const &m) const {
 
       (cudaMalloc((void **)&d_A, size_A));
       (cudaMalloc((void **)&d_B, size_B));
-      (cudaMemcpy(d_A, &data.data()[i], size_A, cudaMemcpyHostToDevice));
-      (cudaMemcpy(d_B, &m.data.data()[j], size_B, cudaMemcpyHostToDevice));
+      (cudaMemcpy(d_A, &data[i], size_A, cudaMemcpyHostToDevice));
+      (cudaMemcpy(d_B, &mdata[j], size_B, cudaMemcpyHostToDevice));
       (cudaMalloc((void **)&d_C, size_C));
 
       {
@@ -677,8 +675,7 @@ Tensor Tensor::dot(Tensor const &m) const {
                      dim.height(), dim.width(), &alpha, d_B, m.dim.width(), d_A,
                      dim.width(), &beta, d_C, m.dim.width()));
 
-        (cudaMemcpy(&result.data.data()[ii], d_C, size_C,
-                    cudaMemcpyDeviceToHost));
+        (cudaMemcpy(&rdata[ii], d_C, size_C, cudaMemcpyDeviceToHost));
         (cublasDestroy(handle));
       }
     }
@@ -692,10 +689,9 @@ Tensor Tensor::dot(Tensor const &m) const {
         for (j = 0; j < m.dim.width(); ++j) {
           for (h = 0; h < dim.width(); ++h) {
             w += data[k * dim.height() * dim.width() + i * dim.width() + h] *
-                 m.data[h * m.dim.width() + j];
+                 mdata[h * m.dim.width() + j];
           }
-          result
-            .data[k * dim.height() * m.dim.width() + i * m.dim.width() + j] = w;
+          rdata[k * dim.height() * m.dim.width() + i * m.dim.width() + j] = w;
           w = 0.0;
         }
       }
@@ -705,12 +701,10 @@ Tensor Tensor::dot(Tensor const &m) const {
       for (i = 0; i < dim.height(); i++) {
         for (j = 0; j < m.dim.width(); j++) {
           for (h = 0; h < dim.width(); h++) {
-            w +=
-              data[k * dim.height() * dim.width() + i * dim.width() + h] *
-              m.data[k * dim.width() * m.dim.width() + h * m.dim.width() + j];
+            w += data[k * dim.height() * dim.width() + i * dim.width() + h] *
+                 mdata[k * dim.width() * m.dim.width() + h * m.dim.width() + j];
           }
-          result
-            .data[k * dim.height() * m.dim.width() + i * m.dim.width() + j] = w;
+          rdata[k * dim.height() * m.dim.width() + i * m.dim.width() + j] = w;
           w = 0.0;
         }
       }
@@ -742,7 +736,7 @@ Tensor Tensor::transpose(std::string direction) const {
 
   SL = fromDim[0], SI = fromDim[1], SJ = fromDim[2], SK = fromDim[3];
 
-  inptr = data.data();
+  inptr = getData();
   outptr = result.getData();
 
   switch (indexI) {
@@ -776,8 +770,12 @@ Tensor Tensor::apply(std::function<float(float)> f) const {
   Tensor result(dim.batch(), dim.channel(), dim.height(), dim.width());
   unsigned int i;
 
-  for (i = 0; i < dim.getDataLen(); ++i)
-    result.data[i] = f(data[i]);
+  const float *data = getData();
+  float *rdata = result.getData();
+  unsigned int len = length();
+
+  for (i = 0; i < len; ++i)
+    rdata[i] = f(data[i]);
 
   return result;
 }
@@ -787,6 +785,8 @@ Tensor Tensor::apply(std::function<Tensor(Tensor)> f) const { return f(*this); }
 void Tensor::print(std::ostream &out) const {
   unsigned int i, j, k, l;
   std::stringstream ss;
+
+  const float *data = getData();
   for (k = 0; k < dim.batch(); k++) {
     for (l = 0; l < dim.channel(); l++) {
       for (i = 0; i < dim.height(); i++) {
@@ -813,30 +813,17 @@ float *Tensor::getAddress(unsigned int i) {
     ml_loge("Error: Index out of bounds");
     return nullptr;
   }
-  return &data[i];
+
+  return &getData()[i];
 }
 
-Tensor &Tensor::copy(const Tensor &from) {
+Tensor &Tensor::copy(Tensor &from) {
   // todo: enable copy to non-contiguous tensor
-  if(!_is_contiguous) {
+  if (!is_contiguous) {
     throw std::runtime_error("Cannot copy non-contiguous tensor");
   }
 
-  if (this != &from && from.dim.getDataLen() != 0) {
-    dim.channel(from.dim.channel());
-    dim.height(from.dim.height());
-    dim.width(from.dim.width());
-    dim.batch(from.dim.batch());
-    if (this->data.empty()) {
-      this->data.resize(from.data.size());
-    }
-#ifdef USE_BLAS
-    cblas_scopy(dim.getDataLen(), from.data.data(), 1, this->data.data(), 1);
-#else
-    for (unsigned int i = 0; i < dim.getDataLen(); ++i)
-      data[i] = from.data[i];
-#endif
-  }
+  *this = from;
 
   return *this;
 }
@@ -852,13 +839,11 @@ int Tensor::setDim(TensorDim d) {
 }
 
 void Tensor::save(std::ofstream &file) {
-  for (unsigned int i = 0; i < dim.getDataLen(); i++)
-    file.write((char *)&data[i], sizeof(float));
+  file.write((char *)getData(), getSize());
 }
 
 void Tensor::read(std::ifstream &file) {
-  for (unsigned int i = 0; i < dim.getDataLen(); i++)
-    file.read((char *)&data[i], sizeof(float));
+  file.read((char *)getData(), getSize());
 }
 
 /**
@@ -872,23 +857,29 @@ Tensor Tensor::average(int axis) const {
   TensorDim out_dim = dim;
   out_dim.setTensorDim(axis, 1);
 
-  Tensor result(out_dim);
-  result = this->sum(axis);
+  Tensor result;
+  result = std::move(this->sum(axis));
   result.divide_i(dim.batch());
 
   return result;
 }
 
-void Tensor::setValue(float val) { std::fill(data.begin(), data.end(), val); }
+void Tensor::setValue(float val) {
+  float *data = getData();
+  std::fill(data, data + length(), val);
+}
 
 void Tensor::setZero() { setValue(0); }
 
 int Tensor::argmax() {
   int index = 0;
   float maximum = min_limits;
-  for (unsigned int i = 0; i < dim.getDataLen(); i++) {
-    if (this->data[i] > maximum) {
-      maximum = this->data[i];
+  float *data = getData();
+  unsigned int len = length();
+
+  for (unsigned int i = 0; i < len; i++) {
+    if (data[i] > maximum) {
+      maximum = data[i];
       index = i;
     }
   }
@@ -896,10 +887,11 @@ int Tensor::argmax() {
 }
 
 float Tensor::l2norm() const {
-  unsigned int len = dim.getDataLen();
+  unsigned int len = length();
+  const float *data = getData();
 
 #ifdef USE_BLAS
-  return cblas_snrm2(len, this->getData(), 1);
+  return cblas_snrm2(len, data, 1);
 #else
   // fix me: to the version that does not allow overflow
   float sum = 0.0;
@@ -918,6 +910,8 @@ Tensor Tensor::normalization() const {
   float Min = max_limits;
   float Max = min_limits;
 
+  const float *data = getData();
+
   for (unsigned int k = 0; k < dim.batch(); ++k) {
     for (unsigned int l = 0; l < dim.channel(); ++l) {
       for (unsigned int i = 0; i < dim.height(); ++i) {
@@ -925,10 +919,10 @@ Tensor Tensor::normalization() const {
           unsigned int id = k * dim.getFeatureLen() +
                             l * dim.height() * dim.width() + i * dim.width() +
                             j;
-          if (this->data[id] < Min)
-            Min = this->data[id];
-          if (this->data[id] > Max)
-            Max = this->data[id];
+          if (data[id] < Min)
+            Min = data[id];
+          if (data[id] > Max)
+            Max = data[id];
         }
       }
     }
@@ -946,6 +940,9 @@ LazyTensor Tensor::chain() const { return LazyTensor(*this); }
 Tensor Tensor::standardization() const {
   Tensor result(dim);
 
+  const float *data = getData();
+  float *rdata = result.getData();
+
   for (unsigned int k = 0; k < dim.batch(); ++k) {
     int K = k * dim.getFeatureLen();
     float mean;
@@ -959,7 +956,7 @@ Tensor Tensor::standardization() const {
         unsigned int I = L + i * dim.width();
         for (unsigned int j = 0; j < dim.width(); ++j) {
           unsigned int J = I + j;
-          mean_tmp += this->data[J];
+          mean_tmp += data[J];
         }
       }
     }
@@ -972,7 +969,7 @@ Tensor Tensor::standardization() const {
         unsigned int I = L + i * dim.width();
         for (unsigned int j = 0; j < dim.width(); ++j) {
           unsigned int J = I + j;
-          std_tmp += (this->data[J] - mean) * (this->data[J] - mean);
+          std_tmp += (data[J] - mean) * (data[J] - mean);
         }
       }
     }
@@ -984,7 +981,7 @@ Tensor Tensor::standardization() const {
         unsigned int I = L + i * dim.width();
         for (unsigned int j = 0; j < dim.width(); ++j) {
           unsigned int J = I + j;
-          result.data[J] = (this->data[J] - mean) / std_dev;
+          rdata[J] = (data[J] - mean) / std_dev;
         }
       }
     }