[Tensor] Quantized Tensor (Int 8) with Scale
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Wed, 30 Aug 2023 23:19:30 +0000 (08:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 01:53:01 +0000 (10:53 +0900)
- Quantized Tensor is now present with Int 8 with scale.
- Dequantization is performed by multiplying values by a scaling factor for channels.
- Only read, write, and dequantization operations are allowed.

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
api/ccapi/include/tensor_dim.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/tensor/tensor_dim.cpp
nntrainer/utils/base_properties.h
test/unittest/unittest_nntrainer_tensor.cpp
test/unittest/unittest_nntrainer_tensor_nhwc.cpp
test/unittest/unittest_nntrainer_tensor_pool.cpp

index 93a3560..0bf24be 100644 (file)
@@ -48,12 +48,13 @@ public:
   enum class Format { NCHW, NHWC };
 
   /**
-   * @brief Tensor Data Type. Currently FP16 & FP32 Support
+   * @brief Tensor Data Type. Currently QINT8, FP16 & FP32 Support
    *
    */
   enum class DataType {
-    FP16, /** half precion */
-    FP32  /** single precision */
+    QINT8, /** quantized int 8*/
+    FP16,  /** half precision */
+    FP32   /** single precision */
   };
 
   /**
@@ -93,7 +94,7 @@ public:
    * @brief     Creator of TensorDim with Format & DataType
    *
    * @param fm format NCHW | HNWC
-   * @param fm DataType FP16 | FP32
+   * @param fm DataType QINT8 | FP16 | FP32
    * @param eff_dim_flag_ effective dimension flag (1 means it's effective)
    * @param dyn_dim_flag_ dynamic dimension flag (1 means it's unspecified)
    */
@@ -157,7 +158,7 @@ public:
    * @param h height
    * @param w width
    * @param fm format NCHW | HNWC
-   * @param d_type Data Type FP16 | FP32
+   * @param d_type Data Type QINT8 | FP16 | FP32
    * @param eff_dim_flag_ dimension bit flag to calculate the dynamic
    * dimension, rightmost is width
    */
@@ -186,7 +187,7 @@ public:
    *
    * @param shape shape of format
    * @param fm format NCHW | HNWC
-   * @param d_type data type FP16 | FP32
+   * @param d_type data type QINT8 | FP16 | FP32
    */
   TensorDim(const std::string &shape, TensorDim::Format fm,
             TensorDim::DataType d_type = TensorDim::DataType::FP32);
index f965029..8f3d89d 100644 (file)
@@ -127,7 +127,8 @@ public:
   SrcSharedTensor() : src(nullptr), off(0) {}
 
   SrcSharedTensor(const Tensor *tensor, size_t offset) :
-    src(tensor), off(offset) {}
+    src(tensor),
+    off(offset) {}
 
   /**
    * @brief   Get the allocated src tensor
@@ -181,6 +182,12 @@ void Tensor::allocate() {
 #else
       throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+    } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
+      mem_data = new MemoryData((void *)(new int8_t[dim.getDataLen()]{}));
+      data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
+        delete[] mem_data->template getAddr<int8_t>();
+        delete mem_data;
+      });
     }
     offset = 0;
     initialize();
@@ -227,6 +234,17 @@ bool Tensor::operator==(const Tensor &rhs) const {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    const int8_t *_data = getData<int8_t>();
+    const int8_t *_rdata = rhs.getData<int8_t>();
+    for (size_t i = 0; i < len; ++i) {
+      /** not checking sign change is intentional to avoid float calculation
+       * errors around 0 */
+      if ((std::isnan(_data[i]) && !std::isnan(_rdata[i])) ||
+          (!std::isnan(_data[i]) && std::isnan(_rdata[i])) ||
+          _data[i] != _rdata[i])
+        return false;
+    }
   }
 
   return true;
@@ -243,6 +261,8 @@ void Tensor::setRandNormal(float mean, float std) {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (this->getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    throw std::invalid_argument("Error: RandNormal is invalid for QINT8");
   }
 }
 
@@ -257,6 +277,8 @@ void Tensor::setRandUniform(float min, float max) {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (this->getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    throw std::invalid_argument("Error: RandUniform is invalid for QINT8");
   }
 }
 
@@ -271,6 +293,8 @@ void Tensor::setRandBernoulli(float probability) {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (this->getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    throw std::invalid_argument("Error: setRandBernoulli is invalid for QINT8");
   }
 }
 
@@ -2644,6 +2668,51 @@ void Tensor::print(std::ostream &out) const {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    const int8_t *data = getData<int8_t>();
+    unsigned int len = size();
+    out << "data addr: " << reinterpret_cast<const float *>(data) << '\n';
+    out << dim;
+
+    if (len > 100) {
+      out << '[' << (int)data[0] << ' ' << (int)data[1] << ' ' << (int)data[2]
+          << " ... " << (int)data[len - 3] << ' ' << (int)data[len - 2] << ' '
+          << (int)data[len - 1] << ']' << std::endl;
+      return;
+    }
+
+    std::ios init(NULL);
+    init.copyfmt(out);
+    if (getFormat() == Tformat::NCHW) {
+      for (unsigned int k = 0; k < batch(); k++) {
+        for (unsigned int l = 0; l < channel(); l++) {
+          for (unsigned int i = 0; i < height(); i++) {
+            for (unsigned int j = 0; j < width(); j++) {
+              out << std::setw(10) << (int)this->getValue<int8_t>(k, l, i, j)
+                  << " ";
+            }
+            out << std::endl;
+          }
+          out << std::endl;
+        }
+        out << "-------" << std::endl;
+      }
+    } else {
+      for (unsigned int k = 0; k < batch(); k++) {
+        for (unsigned int i = 0; i < height(); i++) {
+          for (unsigned int j = 0; j < width(); j++) {
+            for (unsigned int l = 0; l < channel(); l++) {
+              out << std::setw(10) << (int)this->getValue<int8_t>(k, l, i, j)
+                  << " ";
+            }
+            out << std::endl;
+          }
+          out << std::endl;
+        }
+        out << "-------" << std::endl;
+      }
+      out.copyfmt(init);
+    }
   }
 }
 
@@ -2742,7 +2811,11 @@ void Tensor::copy(const void *buf) {
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     << getName() << "Tensor is not contiguous, cannot copy.";
 
-  if (getDataType() == ml::train::TensorDim::DataType::FP16) {
+  if (getDataType() == ml::train::TensorDim::DataType::FP32) {
+    if (buf == getData()) {
+      return;
+    }
+  } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     if (buf == getData<_FP16>()) {
       return;
@@ -2750,23 +2823,24 @@ void Tensor::copy(const void *buf) {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
-  } else if (getDataType() == ml::train::TensorDim::DataType::FP32) {
-    if (buf == getData()) {
+  } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    if (buf == getData<int8_t>()) {
       return;
     }
   }
-  // std::string type_ =
-  //   (getDataType() == ml::train::TensorDim::DataType::FP16) ? "FP16" : "NO";
-  // std::cout << type_ << std::endl;
 
-  if (getDataType() == ml::train::TensorDim::DataType::FP16) {
+  if (getDataType() == ml::train::TensorDim::DataType::FP32) {
+    scopy(size(), (float *)buf, 1, getData<float>(), 1);
+  } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     scopy(size(), (_FP16 *)buf, 1, getData<_FP16>(), 1);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
-  } else if (getDataType() == ml::train::TensorDim::DataType::FP32) {
-    scopy(size(), (float *)buf, 1, getData<float>(), 1);
+  } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    for (unsigned int i = 0; i < size(); ++i) {
+      getData<int8_t>()[i] = ((int8_t *)buf)[i];
+    }
   }
 }
 
@@ -3027,6 +3101,9 @@ void Tensor::setValue(float val) {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    int8_t *data = getData<int8_t>();
+    std::fill(data, data + size(), val);
   }
 }
 
@@ -3045,6 +3122,8 @@ void Tensor::setZero() {
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
+  } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    apply_i<int8_t>([](int8_t val) -> int8_t { return 0; });
   }
 }
 
@@ -3343,4 +3422,69 @@ Tensor Tensor::rotate_180(Tensor in) {
   return output;
 }
 
+void Tensor::setScaleFactors(std::vector<float> scales) {
+  if (!scale_factors.empty()) {
+    throw std::invalid_argument("Error: scale factors already been set");
+  }
+
+  if (scales.size() != channel()) {
+    throw std::invalid_argument("Error: scale_factors.size() != channel() ");
+  }
+
+  scale_factors = scales;
+}
+
+std::vector<float> Tensor::getScaleFactors() { return scale_factors; }
+
+Tensor Tensor::dequantize(Tdatatype dtype) const {
+  Tensor t = Tensor(batch(), channel(), height(), width(), getFormat(), dtype);
+
+  return dequantize(t);
+}
+
+Tensor Tensor::dequantize(Tensor &output) const {
+  if (getDataType() == Tdatatype::FP32 || getDataType() == Tdatatype::FP16) {
+    throw std::invalid_argument("Error: Tensor cannot be dequantized");
+  }
+
+  if (output.getDataType() == Tdatatype::QINT8) {
+    throw std::invalid_argument("Error: Target datatype is quantized type");
+  }
+
+  if (getFormat() != output.getFormat())
+    throw std::invalid_argument("Error: TensorType do not match");
+
+  if (batch() != output.batch() || channel() != output.channel() ||
+      width() != output.width() || height() != output.height())
+    throw std::invalid_argument("Error: TensorDim do not match");
+
+  if (scale_factors.empty()) {
+    throw std::invalid_argument("Error: No scale factors");
+  }
+
+  for (unsigned int c = 0; c < channel(); ++c) {
+    for (unsigned int b = 0; b < batch(); ++b) {
+      for (unsigned int h = 0; h < height(); ++h) {
+        for (unsigned int w = 0; w < width(); ++w) {
+          if (output.getDataType() == Tdatatype::FP32) {
+            output.setValue(b, c, h, w,
+                            (float)getValue<int8_t>(b, c, h, w) *
+                              scale_factors[c]);
+          } else if (output.getDataType() == Tdatatype::FP16) {
+#ifdef ENABLE_FP16
+            output.setValue(b, c, h, w,
+                            (_FP16)getValue<int8_t>(b, c, h, w) *
+                              (_FP16)scale_factors[c]);
+#else
+            throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+          }
+        }
+      }
+    }
+  }
+
+  return output;
+} // namespace nntrainer
+
 } /* namespace nntrainer */
index 559d2a4..8acb451 100644 (file)
@@ -347,6 +347,86 @@ public:
 #endif
 
   /**
+   * @brief     Constructor of Tensor
+   * @param[in] d data for the Tensor. It needs to set format properly.
+   * @param[in] t_type Tensor type.
+   */
+  Tensor(std::vector<std::vector<std::vector<std::vector<int8_t>>>> const &d,
+         ml::train::TensorDim::TensorType t_type) {
+    if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
+      throw std::out_of_range(
+        "[Tensor] trying to initialize Tensor from empty vector");
+    }
+
+    if (t_type.data_type != Tdatatype::QINT8) {
+      throw std::out_of_range(
+        "[Tensor] TensorType do not match with input data type");
+    }
+
+    // if fm == Tformat::NCHW, then dim[0] == batch , dim[1] == channel, dim[2]
+    // == height, dim[3] == width. and if fm == Tformat::NHWC, dim[0] == batch,
+    // dim[1] == height, dim[2] == width, dim[3] == channel
+    dim.setTensorDim(0, d.size());
+    if (t_type.format == Tformat::NCHW) {
+      dim.setTensorDim(1, d[0].size());
+      dim.setTensorDim(2, d[0][0].size());
+      dim.setTensorDim(3, d[0][0][0].size());
+    } else {
+      dim.setTensorDim(2, d[0].size());
+      dim.setTensorDim(3, d[0][0].size());
+      dim.setTensorDim(1, d[0][0][0].size());
+    }
+
+    setTensorType(t_type);
+
+    strides = dim.computeStrides();
+
+    MemoryData *mem_data =
+      new MemoryData((void *)(new int8_t[dim.getDataLen()]()));
+    data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
+      delete[] mem_data->getAddr<int8_t>();
+    });
+    offset = 0;
+    contiguous = true;
+    initializer = Initializer::NONE;
+
+    // if fm == Tformat::NCHW, then dim[0] == batch , dim[1] == channel, dim[2]
+    // == height, dim[3] == width. and if fm == Tformat::NHWC, dim[0] == batch,
+    // dim[1] == height, dim[2] == width, dim[3] == channel
+    if (t_type.format == Tformat::NCHW) {
+      for (unsigned int i = 0; i < batch(); ++i)
+        for (unsigned int j = 0; j < channel(); ++j)
+          for (unsigned int k = 0; k < height(); ++k)
+            for (unsigned int l = 0; l < width(); ++l)
+              this->setValue(i, j, k, l, d[i][j][k][l]);
+    } else {
+      for (unsigned int i = 0; i < batch(); ++i)
+        for (unsigned int j = 0; j < height(); ++j)
+          for (unsigned int k = 0; k < width(); ++k)
+            for (unsigned int l = 0; l < channel(); ++l)
+              this->setValue(i, l, j, k, d[i][j][k][l]);
+    }
+  };
+
+  /**
+   * @brief     Constructor of Tensor
+   * @note      This constructor copies vector again. needs refactoring
+   * @param[in] d data for the Tensor. It needs to set format properly.
+   */
+  Tensor(std::vector<std::vector<std::vector<int8_t>>> const &d,
+         ml::train::TensorDim::TensorType t_type) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
+
+  /**
+   * @brief     Constructor of Tensor
+   * @note      This constructor copies vector again. needs refactoring
+   * @param[in] d data for the Tensor with batch size one
+   */
+  Tensor(std::vector<std::vector<int8_t>> const &d,
+         ml::train::TensorDim::TensorType t_type) :
+    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
+
+  /**
    *  @brief  Copy constructor of Tensor.
    *  @param[in] Tensor &
    */
@@ -1292,6 +1372,8 @@ public:
 #else
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
+    } else if (getDataType() == Tdatatype::QINT8) {
+      getData<int8_t>()[getIndex(batch, c, h, w)] = value;
     }
   }
 
@@ -1317,6 +1399,9 @@ public:
 #else
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
+    } else if (getDataType() == Tdatatype::QINT8) {
+      getData<int8_t>()[idx] *= beta;
+      getData<int8_t>()[idx] += value;
     }
   }
 
@@ -1789,6 +1874,32 @@ public:
    */
   Tdatatype getDataType() const { return dim.getDataType(); }
 
+  /**
+   * @brief     Set scale factors of the tensor
+   * @param[in] scales scale factors
+   */
+  void setScaleFactors(std::vector<float> scales);
+
+  /**
+   * @brief     Get scale factors of the tensor
+   * @retval    scales scale factors
+   */
+  std::vector<float> getScaleFactors();
+
+  /**
+   * @brief     Dequantize Tensor to dtype
+   * @param[in] dtype Target Tensor DataType
+   * @retval    Dequantized Tensor
+   */
+  Tensor dequantize(Tdatatype dtype) const;
+
+  /**
+   * @brief      Dequantize Tensor to output tensor datatype
+   * @param[out] output Tensor to store the result
+   * @retval     Dequantized Tensor
+   */
+  Tensor dequantize(Tensor &output) const;
+
   static constexpr float epsilon = 1e-5;
 
 private:
@@ -1800,6 +1911,7 @@ private:
   std::string name; /**< name of the tensor */
   std::shared_ptr<MemoryData> data;
   size_t offset;
+  std::vector<float> scale_factors;
 
   /**<
    * When using shared_data with tensor, this stores the ptr of the source
index b99e18f..5bb65de 100644 (file)
@@ -33,7 +33,9 @@ TensorDim::TensorDim(TensorDim::Format fm, TensorDim::DataType d_type,
 TensorDim::TensorDim(TensorType t_type_,
                      const std::bitset<MAXDIM> &eff_dim_flag_,
                      const std::bitset<MAXDIM> &dyn_dim_flag_) :
-  t_type(t_type_), eff_dim_flag(eff_dim_flag_), dyn_dim_flag(dyn_dim_flag_) {
+  t_type(t_type_),
+  eff_dim_flag(eff_dim_flag_),
+  dyn_dim_flag(dyn_dim_flag_) {
   for (size_t i = 0; i < MAXDIM; ++i) {
     dim[i] = 0;
   }
@@ -121,6 +123,8 @@ uint TensorDim::getDataTypeSize() const {
 #endif
   case TensorDim::DataType::FP32:
     return sizeof(float);
+  case TensorDim::DataType::QINT8:
+    return sizeof(int8_t);
   default:
     return sizeof(float);
   }
@@ -333,8 +337,15 @@ bool TensorDim::is_dynamic() const { return dyn_dim_flag.any(); }
 
 std::ostream &operator<<(std::ostream &out, TensorDim const &d) {
 
-  std::string type_ =
-    (d.getDataType() == ml::train::TensorDim::DataType::FP16) ? "FP16" : "FP32";
+  std::string type_;
+  if (d.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    type_ = "FP32";
+  } else if (d.getDataType() == ml::train::TensorDim::DataType::FP16) {
+    type_ = "FP16";
+  } else if (d.getDataType() == ml::train::TensorDim::DataType::QINT8) {
+    type_ = "QINT8";
+  }
+
   std::string format_ =
     (d.getFormat() == ml::train::TensorDim::Format::NCHW) ? "NCHW" : "NHWC";
   out << "Shape: " << d.batch() << ":" << d.channel() << ":" << d.height()
index 25662ae..9531860 100644 (file)
@@ -647,12 +647,15 @@ void from_string(const std::string &value, std::vector<T> &property) {
  */
 struct TensorDataTypeInfo {
   using Enum = nntrainer::TensorDim::DataType;
-  static constexpr std::initializer_list<Enum> EnumList = {Enum::FP16,
-                                                           Enum::FP32};
+  static constexpr std::initializer_list<Enum> EnumList = {
+    Enum::QINT8, Enum::FP16, Enum::FP32};
 
-  static constexpr const char *EnumStr[] = {"FP16", "FP32"};
+  static constexpr const char *EnumStr[] = {"QINT8", "FP16", "FP32"};
 };
 
+/**
+ * @brief     Enumeration of Format for model & layer
+ */
 struct TensorFormatInfo {
   using Enum = nntrainer::TensorDim::Format;
   static constexpr std::initializer_list<Enum> EnumList = {Enum::NCHW,
@@ -671,6 +674,12 @@ class TensorDataType final : public EnumProperty<TensorDataTypeInfo> {
 public:
   using prop_tag = enum_class_prop_tag;
   static constexpr const char *key = "tensor_type";
+
+  /**
+   * @brief Constructor
+   *
+   * @param value value to set, defaults to FP32
+   */
   TensorDataType(
     TensorDataTypeInfo::Enum value = TensorDataTypeInfo::Enum::FP32) {
     set(value);
@@ -684,13 +693,13 @@ public:
 class TensorFormat final : public EnumProperty<TensorFormatInfo> {
 public:
   static constexpr const char *key =
-    "tensor_format";             /**< unique key to access */
+    "tensor_format";                    /**< unique key to access */
   using prop_tag = enum_class_prop_tag; /**< property type */
 
   /**
    * @brief Constructor
    *
-   * @param value value to set, defaults to false
+   * @param value value to set, defaults to NCHW
    */
   TensorFormat(TensorFormatInfo::Enum value = TensorFormatInfo::Enum::NCHW) {
     set(value);
index 17ae168..0f4731b 100644 (file)
@@ -188,6 +188,34 @@ TEST(nntrainer_Tensor, Tensor_03_p) {
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
 
+TEST(nntrainer_Tensor, Tensor_04_p) {
+  int status = ML_ERROR_NONE;
+  int batch = 3;
+  int height = 3;
+  int width = 10;
+  std::vector<std::vector<std::vector<int8_t>>> in;
+
+  for (int k = 0; k < batch; ++k) {
+    std::vector<std::vector<int8_t>> ttv;
+    for (int i = 0; i < height; ++i) {
+      std::vector<int8_t> tv;
+      for (int j = 0; j < width; ++j) {
+        tv.push_back(k * height * width + i * width + j);
+      }
+      ttv.push_back(tv);
+    }
+    in.push_back(ttv);
+  }
+
+  nntrainer::Tensor tensor = nntrainer::Tensor(
+    in, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  ASSERT_NE(nullptr, tensor.getData<int8_t>());
+
+  if (tensor.getValue<int8_t>(0, 0, 0, 1) != 1)
+    status = ML_ERROR_INVALID_PARAMETER;
+  EXPECT_EQ(status, ML_ERROR_NONE);
+}
+
 TEST(nntrainer_Tensor, multiply_i_01_p) {
   int status = ML_ERROR_NONE;
   int batch = 3;
@@ -4238,6 +4266,171 @@ TEST(nntrainer_Tensor, multiply_strided_06_p) {
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
 
+/**
+ * @brief dequantize FP32 tensor
+ */
+TEST(nntrainer_Tensor, dequantize_01_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(batch, channel, height, width);
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(batch, channel, height, width);
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize tensor with different dimension
+ */
+TEST(nntrainer_Tensor, dequantize_02_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch + 1, channel, height + 1, width + 1,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(batch, channel, height, width);
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize tensor with no scale factors
+ */
+TEST(nntrainer_Tensor, dequantize_03_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+
+  nntrainer::Tensor output(batch, channel, height, width);
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize tensor with incorrect number of scale factors
+ */
+TEST(nntrainer_Tensor, dequantize_04_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  EXPECT_THROW(
+    {
+      input.setScaleFactors({2.0, 1.5, 1.0, 0.5});
+    },
+    std::invalid_argument);
+}
+
+/**
+ * @brief dequantize tensor to QINT8
+ */
+TEST(nntrainer_Tensor, dequantize_05_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize qint8 tensor
+ */
+TEST(nntrainer_Tensor, dequantize_06_p) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output;
+
+  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP32); });
+
+  float answer_data[] = {
+    1.5, 1.5, 1.5, 1.5, 1.5, 3,   3,   3,   3,   3,   4.5, 4.5, 4.5, 4.5, 4.5,
+    6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,
+    8,   8,   8,   8,   8,   9,   9,   9,   9,   9,   5.5, 5.5, 5.5, 5.5, 5.5,
+    6,   6,   6,   6,   6,   6.5, 6.5, 6.5, 6.5, 6.5, 7,   7,   7,   7,   7};
+
+  nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
+                                                {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::FP32}),
+                           answer_data);
+
+  EXPECT_EQ(output, answer);
+}
+
+/**
+ * @brief dequantize tensor
+ */
+TEST(nntrainer_Tensor, dequantize_07_p) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(batch, channel, height, width);
+
+  EXPECT_NO_THROW({ input.dequantize(output); });
+
+  float answer_data[] = {
+    1.5, 1.5, 1.5, 1.5, 1.5, 3,   3,   3,   3,   3,   4.5, 4.5, 4.5, 4.5, 4.5,
+    6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,
+    8,   8,   8,   8,   8,   9,   9,   9,   9,   9,   5.5, 5.5, 5.5, 5.5, 5.5,
+    6,   6,   6,   6,   6,   6.5, 6.5, 6.5, 6.5, 6.5, 7,   7,   7,   7,   7};
+
+  nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
+                                                {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::FP32}),
+                           answer_data);
+
+  EXPECT_EQ(output, answer);
+}
+
 int main(int argc, char **argv) {
   int result = -1;
 
index e444c7e..37bcd92 100644 (file)
@@ -213,7 +213,8 @@ TEST(nntrainer_Tensor, multiply_i_03_nhwc_n) {
   GEN_TEST_INPUT(input, i * (channel * height * width) + j * (width * channel) +
                           k * channel + l);
 
-  nntrainer::Tensor target2(batch, channel, height - 2, width - 1, NHWC_, FP32_);
+  nntrainer::Tensor target2(batch, channel, height - 2, width - 1, NHWC_,
+                            FP32_);
   status = input.multiply_i(target2);
 
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
@@ -670,7 +671,8 @@ TEST(nntrainer_Tensor, divide_i_02_nhwc_n) {
   GEN_TEST_INPUT_NHWC(input, i * (height * width * channel) +
                                j * (width * channel) + k * channel + l + 1);
 
-  nntrainer::Tensor original(batch, channel, height - 2, width - 1, NHWC_, FP32_);
+  nntrainer::Tensor original(batch, channel, height - 2, width - 1, NHWC_,
+                             FP32_);
 
   status = input.divide_i(original);
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
@@ -721,7 +723,8 @@ TEST(nntrainer_Tensor, divide_03_nhwc_n) {
   GEN_TEST_INPUT_NHWC(input, i * (height * width * channel) +
                                j * (width * channel) + k * channel + l + 1);
 
-  nntrainer::Tensor test(batch - 1, channel - 1, height - 1, width - 1, NHWC_, FP32_);
+  nntrainer::Tensor test(batch - 1, channel - 1, height - 1, width - 1, NHWC_,
+                         FP32_);
 
   EXPECT_THROW({ input.divide(test); }, std::invalid_argument);
 }
@@ -1453,7 +1456,8 @@ TEST(nntrainer_Tensor, add_03_nhwc_n) {
   GEN_TEST_INPUT_NHWC(input, i * (height * width * channel) +
                                j * (width * channel) + k * channel + 1);
 
-  nntrainer::Tensor test(batch - 1, height - 1, width - 1, channel, NHWC_, FP32_);
+  nntrainer::Tensor test(batch - 1, height - 1, width - 1, channel, NHWC_,
+                         FP32_);
 
   EXPECT_THROW({ input.add(test); }, std::invalid_argument);
 }
@@ -1604,7 +1608,8 @@ TEST(nntrainer_Tensor, subtract_i_03_nhwc_n) {
   GEN_TEST_INPUT_NHWC(target, i * (height * width * channel) +
                                 j * (width * channel) + k * channel + l);
 
-  nntrainer::Tensor target2(batch, height, width - 3, channel - 1, NHWC_, FP32_);
+  nntrainer::Tensor target2(batch, height, width - 3, channel - 1, NHWC_,
+                            FP32_);
 
   status = target.subtract_i(target2);
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
@@ -1663,7 +1668,8 @@ TEST(nntrainer_Tensor, subtract_03_nhwc_n) {
   GEN_TEST_INPUT_NHWC(input, i * (height * width * channel) +
                                j * (width * channel) + k * channel + 1);
 
-  nntrainer::Tensor test(batch - 1, channel - 1, height, width - 1, NHWC_, FP32_);
+  nntrainer::Tensor test(batch - 1, channel - 1, height, width - 1, NHWC_,
+                         FP32_);
 
   EXPECT_THROW({ input.subtract(test); }, std::invalid_argument);
 }
@@ -2528,7 +2534,7 @@ TEST(nntrainer_Tensor, average_nhwc_p) {
   EXPECT_EQ(actual, expected);
 
   int idx = 0;
-  t = t.apply((std::function<float (float)>)[&](float in) { return idx++ % 2; });
+  t = t.apply((std::function<float(float)>)[&](float in) { return idx++ % 2; });
 
   actual = t.average();
   expected = constant(0.5, 1, 1, 1, 1, NHWC_, FP32_);
@@ -3202,7 +3208,8 @@ TEST(nntrainer_Tensor, fill_p) {
   /// same dimension, buffer size
   {
     nntrainer::Tensor target(3, 2, 4, 5, NHWC_, FP32_);
-    nntrainer::Tensor original = randUniform(3, 2, 4, 5, -1.0f, 1.0f, NHWC_, FP32_);
+    nntrainer::Tensor original =
+      randUniform(3, 2, 4, 5, -1.0f, 1.0f, NHWC_, FP32_);
     target.fill(original, false);
 
     EXPECT_EQ(target, original);
@@ -3217,7 +3224,8 @@ TEST(nntrainer_Tensor, fill_p) {
   /// uninitialized with initialized flag is true
   {
     nntrainer::Tensor target;
-    nntrainer::Tensor original = randUniform(3, 2, 4, 5, -1.0f, 1.0f, NHWC_, FP32_);
+    nntrainer::Tensor original =
+      randUniform(3, 2, 4, 5, -1.0f, 1.0f, NHWC_, FP32_);
     target.fill(original, true);
 
     EXPECT_EQ(target, original);
@@ -3226,13 +3234,15 @@ TEST(nntrainer_Tensor, fill_p) {
 
 TEST(nntrainer_Tensor, fill_uninitialized_n) {
   nntrainer::Tensor target;
-  nntrainer::Tensor original = randUniform(3, 1, 2, 3, -1.0f, 1.0f, NHWC_, FP32_);
+  nntrainer::Tensor original =
+    randUniform(3, 1, 2, 3, -1.0f, 1.0f, NHWC_, FP32_);
   EXPECT_THROW(target.fill(original, false), std::invalid_argument);
 }
 
 TEST(nntrainer_Tensor, fill_different_dimension_n) {
   nntrainer::Tensor target(3, 1, 3, 2, NHWC_, FP32_);
-  nntrainer::Tensor original = randUniform(3, 1, 2, 3, -1.0f, 1.0f, NHWC_, FP32_);
+  nntrainer::Tensor original =
+    randUniform(3, 1, 2, 3, -1.0f, 1.0f, NHWC_, FP32_);
   EXPECT_THROW(target.fill(original, false), std::invalid_argument);
 }
 
@@ -3292,7 +3302,8 @@ TEST(nntrainer_Tensor, add_strided_02_nhwc_n) {
   GEN_TEST_INPUT_NHWC(input, i * (height * width * channel) +
                                j * (width * channel) + k * channel + 1);
 
-  nntrainer::Tensor test(batch - 1, height - 1, width - 1, channel, NHWC_, FP32_);
+  nntrainer::Tensor test(batch - 1, height - 1, width - 1, channel, NHWC_,
+                         FP32_);
 
   EXPECT_THROW({ input.add_strided(test); }, std::invalid_argument);
 }
@@ -4664,6 +4675,119 @@ TEST(nntrainer_Tensor, tranpose_dimension_not_match_nhwc_n) {
   EXPECT_THROW(a.transpose("0:1:2", b), std::invalid_argument);
 }
 
+/**
+ * @brief dequantize tensor with different format
+ */
+TEST(nntrainer_Tensor, dequantize_01_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::FP32});
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize tensor with different format
+ */
+TEST(nntrainer_Tensor, dequantize_02_n) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32});
+
+  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+}
+
+/**
+ * @brief dequantize nhwc tensor
+ */
+TEST(nntrainer_Tensor, dequantize_03_p) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output;
+  output.getDim().setFormat(nntrainer::Tformat::NHWC);
+
+  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP32); });
+
+  float answer_data[] = {1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5,
+                         1.5, 6, 5.5, 3,   7, 6,   3,   7, 6,   3,   7, 6,
+                         3,   7, 6,   3,   7, 6,   4.5, 8, 6.5, 4.5, 8, 6.5,
+                         4.5, 8, 6.5, 4.5, 8, 6.5, 4.5, 8, 6.5, 6,   9, 7,
+                         6,   9, 7,   6,   9, 7,   6,   9, 7,   6,   9, 7};
+
+  nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
+                                                {nntrainer::Tformat::NHWC,
+                                                 nntrainer::Tdatatype::FP32}),
+                           answer_data);
+
+  EXPECT_EQ(output, answer);
+}
+
+/**
+ * @brief dequantize nhwc tensor
+ */
+TEST(nntrainer_Tensor, dequantize_04_p) {
+  int batch = 1;
+  int channel = 3;
+  int height = 4;
+  int width = 5;
+
+  nntrainer::Tensor input(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
+  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+
+  nntrainer::Tensor output(
+    batch, channel, height, width,
+    {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::FP32});
+
+  EXPECT_NO_THROW({ input.dequantize(output); });
+
+  float answer_data[] = {1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5,
+                         1.5, 6, 5.5, 3,   7, 6,   3,   7, 6,   3,   7, 6,
+                         3,   7, 6,   3,   7, 6,   4.5, 8, 6.5, 4.5, 8, 6.5,
+                         4.5, 8, 6.5, 4.5, 8, 6.5, 4.5, 8, 6.5, 6,   9, 7,
+                         6,   9, 7,   6,   9, 7,   6,   9, 7,   6,   9, 7};
+
+  nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
+                                                {nntrainer::Tformat::NHWC,
+                                                 nntrainer::Tdatatype::FP32}),
+                           answer_data);
+
+  EXPECT_EQ(output, answer);
+}
+
 int main(int argc, char **argv) {
   int result = -1;
 
index 7b6e2a0..c898818 100644 (file)
@@ -17,6 +17,7 @@
 #include <gtest/gtest.h>
 
 #include <basic_planner.h>
+#include <optimized_v1_planner.h>
 #include <tensor_pool.h>
 
 constexpr unsigned int MEM_BYTES = 128;
@@ -435,6 +436,67 @@ TEST(TensorPool, validate_memory) {
 }
 
 /**
+ * @brief qint8 tensors reuse fp32 tensor memory space
+ */
+TEST(TensorPool, validate_memory_reuse_p) {
+  // |--------- t1 ---------|
+  // |-t2-||-t3-||-t4-||-t5-|
+  nntrainer::TensorPool pool;
+  nntrainer::Tensor *t1 = nullptr, *t2 = nullptr, *t3 = nullptr, *t4 = nullptr,
+                    *t5 = nullptr;
+
+  EXPECT_NO_THROW(
+    t1 = pool.request("t1", nntrainer::TensorDim({4}), {0},
+                      nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN));
+  EXPECT_NE(t1, nullptr);
+  EXPECT_FALSE(t1->isAllocated());
+
+  EXPECT_NO_THROW(
+    t2 = pool.request("t2",
+                      nntrainer::TensorDim({4}, {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::QINT8}),
+                      {1}, nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN));
+  EXPECT_NE(t2, nullptr);
+  EXPECT_FALSE(t2->isAllocated());
+
+  EXPECT_NO_THROW(
+    t3 = pool.request("t3",
+                      nntrainer::TensorDim({4}, {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::QINT8}),
+                      {1}, nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN));
+  EXPECT_NE(t3, nullptr);
+  EXPECT_FALSE(t3->isAllocated());
+
+  EXPECT_NO_THROW(
+    t4 = pool.request("t4",
+                      nntrainer::TensorDim({4}, {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::QINT8}),
+                      {1}, nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN));
+  EXPECT_NE(t4, nullptr);
+  EXPECT_FALSE(t4->isAllocated());
+
+  EXPECT_NO_THROW(
+    t5 = pool.request("t5",
+                      nntrainer::TensorDim({4}, {nntrainer::Tformat::NCHW,
+                                                 nntrainer::Tdatatype::QINT8}),
+                      {1}, nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN));
+  EXPECT_NE(t5, nullptr);
+  EXPECT_FALSE(t5->isAllocated());
+
+  EXPECT_NO_THROW(pool.finalize(nntrainer::OptimizedV1Planner(), 0, 2));
+  EXPECT_EQ(pool.minMemoryRequirement(), t1->bytes());
+
+  EXPECT_NO_THROW(pool.allocate());
+
+  EXPECT_EQ(t1->getAddress<float>(0), (float *)t2->getAddress<int8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(1), (float *)t3->getAddress<int8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(2), (float *)t4->getAddress<int8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(3), (float *)t5->getAddress<int8_t>(0));
+
+  EXPECT_NO_THROW(pool.deallocate());
+}
+
+/**
  * @brief check if data span of two tensor testOverlap
  *
  * @param t1 tensor1
@@ -601,7 +663,8 @@ TEST(TensorPool, view_of_placeholder_p) {
   /// t2        : 0 1 2 3 4 5 6 7 8 9
   /// t3        :     2 3
   nntrainer::Tensor t_original(t1->getDim());
-  t_original.apply_i((std::function<float (float)>)[i = 0u](float _) mutable { return ++i; });
+  t_original.apply_i(
+    (std::function<float(float)>)[i = 0u](float _) mutable { return ++i; });
   pool.fillPlaceholder("t1", t_original);
 
   testSubset(t1, &t_original);