[Tensor] Unsigned Quantized Tensor
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Tue, 12 Sep 2023 00:09:35 +0000 (09:09 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 04:58:02 +0000 (13:58 +0900)
- Quantized tensor values are unsigned with zero points
- Layer context dequantize quantized tensor when request weight
- Template dequantize function

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/models/model_common_properties.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/utils/base_properties.h
test/unittest/unittest_nntrainer_tensor.cpp
test/unittest/unittest_nntrainer_tensor_fp16.cpp
test/unittest/unittest_nntrainer_tensor_nhwc.cpp
test/unittest/unittest_nntrainer_tensor_pool.cpp

index 620d7dd6abd592f515b96fd08d6420c54ed00a4c..e54600e20397fa90ac3811699d24870fe3e162b4 100644 (file)
@@ -20,7 +20,6 @@
 #include <nntrainer_log.h>
 #include <stdexcept>
 #include <var_grad.h>
-#include <weight.h>
 
 namespace nntrainer {
 
@@ -155,16 +154,6 @@ RunLayerContext::RunLayerContext(const std::string &name, bool trainable,
     throw std::invalid_argument("Creating invalid run context");
 }
 
-/**
- * @brief Get the Weight tensor object
- *
- * @param idx Identifier of the weight
- * @return Tensor& Reference to the weight tensor
- */
-Tensor &RunLayerContext::getWeight(unsigned int idx) const {
-  return weights[idx]->getVariableRef();
-}
-
 /**
  * @brief Get the Weight Gradient tensor object
  *
index f268489b0d14a86e1f89086df390a33503cbb1f6..ad7149fec53b07a762a2da6bdd80da7bacf5efb1 100644 (file)
 #include <tensor.h>
 #include <tensor_dim.h>
 #include <tensor_wrap_specs.h>
+#include <weight.h>
 
 namespace nntrainer {
 
-class Weight;
 class Var_Grad;
 
 /**
@@ -112,6 +112,11 @@ public:
    */
   const std::vector<TensorDim> &getInputDimensions() const { return input_dim; }
 
+  /**
+   * @brief Set Data Type for Input Dimensions
+   *
+   * @param ty data type to set
+   */
   void setInputDataType(TensorDim::DataType ty) {
     for (auto d : input_dim)
       d.setDataType(ty);
@@ -390,7 +395,21 @@ public:
    * @param idx Identifier of the weight
    * @return Tensor& Reference to the weight tensor
    */
-  Tensor &getWeight(unsigned int idx) const;
+  template <typename T = float> Tensor &getWeight(unsigned int idx) const {
+    if (weights[idx]->getDim().getDataType() == nntrainer::Tdatatype::QINT4 ||
+        weights[idx]->getDim().getDataType() == nntrainer::Tdatatype::QINT8) {
+      Tensor output(weights[idx]->getDim());
+
+      if (sizeof(T) == sizeof(float)) {
+        output.setDataType(nntrainer::Tdatatype::FP32);
+      } else {
+        output.setDataType(nntrainer::Tdatatype::FP16);
+      }
+
+      return weights[idx]->getVariableRef().dequantize<T>(output);
+    }
+    return weights[idx]->getVariableRef();
+  }
 
   /**
    * @brief Get the Weight Gradient tensor object
index ea8f8f16dd972e797026e906ac12bc3bb46e337c..79aa7ce4fc5c5c0598936f67d52ee93d2ea29bfa 100644 (file)
@@ -183,12 +183,14 @@ public:
  * @brief     Enumeration of Data Type for model & layer
  */
 struct ModelTensorDataTypeInfo {
-  enum Enum { W16A16, W16A32, W32A16, W32A32 };
+  enum Enum { W4A16, W4A32, W8A16, W8A32, W16A16, W16A32, W32A16, W32A32 };
   static constexpr std::initializer_list<Enum> EnumList = {
+    Enum::W4A16,  Enum::W4A32,  Enum::W8A16,  Enum::W8A32,
     Enum::W16A16, Enum::W16A32, Enum::W32A16, Enum::W32A32};
 
-  static constexpr const char *EnumStr[] = {"FP16-FP16", "FP16-FP32",
-                                            "FP32-FP16", "FP32-FP32"};
+  static constexpr const char *EnumStr[] = {
+    "QINT4-FP16", "QINT4-FP32", "QINT8-FP16", "QINT8-FP32",
+    "FP16-FP16",  "FP16-FP32",  "FP32-FP16",  "FP32-FP32"};
 };
 
 /**
@@ -199,6 +201,12 @@ class ModelTensorDataType final : public EnumProperty<ModelTensorDataTypeInfo> {
 public:
   using prop_tag = enum_class_prop_tag;
   static constexpr const char *key = "model_tensor_type";
+
+  /**
+   * @brief Constructor
+   *
+   * @param value value to set, defaults to W32A32
+   */
   ModelTensorDataType(ModelTensorDataTypeInfo::Enum value =
                         ModelTensorDataTypeInfo::Enum::W32A32) {
     set(value);
index dd6597e8370e221f5b0e1d24cfecc34d9c982312..d4989db64ef3a787da952486158fb53f7d85e9b7 100644 (file)
@@ -183,16 +183,16 @@ void Tensor::allocate() {
       throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
     } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
-      mem_data = new MemoryData((void *)(new int8_t[dim.getDataLen()]{}));
+      mem_data = new MemoryData((void *)(new uint8_t[dim.getDataLen()]{}));
       data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
-        delete[] mem_data->template getAddr<int8_t>();
+        delete[] mem_data->template getAddr<uint8_t>();
         delete mem_data;
       });
     } else if (getDataType() == ml::train::TensorDim::DataType::QINT4) {
       mem_data =
-        new MemoryData((void *)(new int8_t[(dim.getDataLen() + 1) / 2]{}));
+        new MemoryData((void *)(new uint8_t[(dim.getDataLen() + 1) / 2]{}));
       data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
-        delete[] mem_data->template getAddr<int8_t>();
+        delete[] mem_data->template getAddr<uint8_t>();
         delete mem_data;
       });
     }
@@ -216,6 +216,12 @@ bool Tensor::operator==(const Tensor &rhs) const {
   if (strides != rhs.strides)
     return false;
 
+  if (getScaleFactors() != rhs.getScaleFactors())
+    return false;
+
+  if (getZeroPoints() != rhs.getZeroPoints())
+    return false;
+
   if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     const float *_data = getData<float>();
     const float *_rdata = rhs.getData<float>();
@@ -242,8 +248,8 @@ bool Tensor::operator==(const Tensor &rhs) const {
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT8) {
-    const int8_t *_data = getData<int8_t>();
-    const int8_t *_rdata = rhs.getData<int8_t>();
+    const uint8_t *_data = getData<uint8_t>();
+    const uint8_t *_rdata = rhs.getData<uint8_t>();
     for (size_t i = 0; i < len; ++i) {
       /** not checking sign change is intentional to avoid float calculation
        * errors around 0 */
@@ -253,9 +259,9 @@ bool Tensor::operator==(const Tensor &rhs) const {
         return false;
     }
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT4) {
-    const int8_t *_data = getData<int8_t>();
-    const int8_t *_rdata = rhs.getData<int8_t>();
-    int8_t data, rdata;
+    const uint8_t *_data = getData<uint8_t>();
+    const uint8_t *_rdata = rhs.getData<uint8_t>();
+    uint8_t data, rdata;
     for (size_t i = 0; i < len; ++i) {
       /** not checking sign change is intentional to avoid float calculation
        * errors around 0 */
@@ -2696,7 +2702,7 @@ void Tensor::print(std::ostream &out) const {
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
-    const int8_t *data = getData<int8_t>();
+    const uint8_t *data = getData<uint8_t>();
     unsigned int len = size();
     out << "data addr: " << reinterpret_cast<const float *>(data) << '\n';
     out << dim;
@@ -2715,7 +2721,7 @@ void Tensor::print(std::ostream &out) const {
         for (unsigned int l = 0; l < channel(); l++) {
           for (unsigned int i = 0; i < height(); i++) {
             for (unsigned int j = 0; j < width(); j++) {
-              out << std::setw(10) << (int)this->getValue<int8_t>(k, l, i, j)
+              out << std::setw(10) << (int)this->getValue<uint8_t>(k, l, i, j)
                   << " ";
             }
             out << std::endl;
@@ -2729,7 +2735,7 @@ void Tensor::print(std::ostream &out) const {
         for (unsigned int i = 0; i < height(); i++) {
           for (unsigned int j = 0; j < width(); j++) {
             for (unsigned int l = 0; l < channel(); l++) {
-              out << std::setw(10) << (int)this->getValue<int8_t>(k, l, i, j)
+              out << std::setw(10) << (int)this->getValue<uint8_t>(k, l, i, j)
                   << " ";
             }
             out << std::endl;
@@ -2741,7 +2747,7 @@ void Tensor::print(std::ostream &out) const {
       out.copyfmt(init);
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT4) {
-    const int8_t *data = getData<int8_t>();
+    const uint8_t *data = getData<uint8_t>();
     unsigned int len = size();
     out << "data addr: " << (float *)data << '\n';
     out << dim;
@@ -2896,11 +2902,11 @@ void Tensor::copy(const void *buf) {
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
-    if (buf == getData<int8_t>()) {
+    if (buf == getData<uint8_t>()) {
       return;
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT4) {
-    if (buf == getData<int8_t>()) {
+    if (buf == getData<uint8_t>()) {
       return;
     }
   }
@@ -2915,11 +2921,11 @@ void Tensor::copy(const void *buf) {
 #endif
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
     for (unsigned int i = 0; i < size(); ++i) {
-      getData<int8_t>()[i] = ((int8_t *)buf)[i];
+      getData<uint8_t>()[i] = ((uint8_t *)buf)[i];
     }
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT4) {
     for (unsigned int i = 0; i < (size() + 1) / 2; ++i) {
-      getData<int8_t>()[i] = ((int8_t *)buf)[i];
+      getData<uint8_t>()[i] = ((uint8_t *)buf)[i];
     }
   }
 }
@@ -3182,11 +3188,11 @@ void Tensor::setValue(float val) {
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT8) {
-    int8_t *data = getData<int8_t>();
+    uint8_t *data = getData<uint8_t>();
     std::fill(data, data + size(), val);
   } else if (getDataType() == ml::train::TensorDim::DataType::QINT4) {
-    int8_t *data = getData<int8_t>();
-    int8_t mixed = encode_qint(val, val);
+    uint8_t *data = getData<uint8_t>();
+    uint8_t mixed = encode_qint(val, val);
     std::fill(data, data + (size() + 1) / 2, mixed);
   }
 }
@@ -3207,9 +3213,9 @@ void Tensor::setZero() {
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT8) {
-    apply_i<int8_t>([](int8_t val) -> int8_t { return 0; });
+    apply_i<uint8_t>([](uint8_t val) -> uint8_t { return 0; });
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT4) {
-    apply_i<int8_t>([](int8_t val) -> int8_t { return 0; });
+    setValue(0);
   }
 }
 
@@ -3508,11 +3514,11 @@ Tensor Tensor::rotate_180(Tensor in) {
   return output;
 }
 
-int8_t Tensor::encode_qint(int8_t high, int8_t low) const {
+uint8_t Tensor::encode_qint(uint8_t high, uint8_t low) const {
   return (high << 4) | (low & 0x0f);
 };
 
-int8_t Tensor::decode_qint(int8_t val, bool isHigh) const {
+uint8_t Tensor::decode_qint(uint8_t val, bool isHigh) const {
   if (isHigh) {
     val = val >> 4;
   } else {
@@ -3523,105 +3529,67 @@ int8_t Tensor::decode_qint(int8_t val, bool isHigh) const {
   return val;
 }
 
-void Tensor::setScaleFactors(std::vector<float> scales, int idx) {
-  if (scales.empty() || idx < 0 || idx > 3) {
+void Tensor::setOutputAxis(int axis) {
+  if (axis < 0 || axis > 3) {
+    throw std::invalid_argument("Error: invalid parameter");
+  }
+  output_axis = axis;
+}
+
+int Tensor::getOutputAxis() const { return output_axis; }
+
+void Tensor::setScaleFactors(std::vector<float> scales) {
+  if (scales.empty()) {
     throw std::invalid_argument("Error: invalid parameter");
   }
 
-  if (idx == 0 && scales.size() != batch()) {
+  if (output_axis == 0 && scales.size() != batch()) {
     throw std::invalid_argument("Error: scale_factors.size() != batch() ");
   }
 
-  if (idx == 1 && scales.size() != channel()) {
+  if (output_axis == 1 && scales.size() != channel()) {
     throw std::invalid_argument("Error: scale_factors.size() != channel() ");
   }
 
-  if (idx == 2 && scales.size() != height()) {
+  if (output_axis == 2 && scales.size() != height()) {
     throw std::invalid_argument("Error: scale_factors.size() != height() ");
   }
 
-  if (idx == 3 && scales.size() != width()) {
+  if (output_axis == 3 && scales.size() != width()) {
     throw std::invalid_argument("Error: scale_factors.size() != width() ");
   }
 
   scale_factors = scales;
-  scale_idx = idx;
 }
 
-std::vector<float> Tensor::getScaleFactors() { return scale_factors; }
+std::vector<float> Tensor::getScaleFactors() const { return scale_factors; }
 
-Tensor Tensor::dequantize(Tdatatype dtype) const {
-  Tensor t = Tensor(batch(), channel(), height(), width(), getFormat(), dtype);
-
-  return dequantize(t);
-}
-
-Tensor Tensor::dequantize(Tensor &output) const {
-  if (getDataType() == Tdatatype::FP32 || getDataType() == Tdatatype::FP16) {
-    throw std::invalid_argument("Error: Tensor cannot be dequantized");
+void Tensor::setZeroPoints(std::vector<uint8_t> zp) {
+  if (zp.empty()) {
+    throw std::invalid_argument("Error: invalid parameter");
   }
 
-  if (output.getDataType() == Tdatatype::QINT8 ||
-      output.getDataType() == Tdatatype::QINT4) {
-    throw std::invalid_argument("Error: Target datatype is quantized type");
+  if (output_axis == 0 && zp.size() != batch()) {
+    throw std::invalid_argument("Error: zero_points.size() != batch() ");
   }
 
-  if (getFormat() != output.getFormat())
-    throw std::invalid_argument("Error: TensorType do not match");
+  if (output_axis == 1 && zp.size() != channel()) {
+    throw std::invalid_argument("Error: zero_points.size() != channel() ");
+  }
 
-  if (batch() != output.batch() || channel() != output.channel() ||
-      width() != output.width() || height() != output.height())
-    throw std::invalid_argument("Error: TensorDim do not match");
+  if (output_axis == 2 && zp.size() != height()) {
+    throw std::invalid_argument("Error: zero_points.size() != height() ");
+  }
 
-  if (scale_factors.empty()) {
-    throw std::invalid_argument("Error: No scale factors");
+  if (output_axis == 3 && zp.size() != width()) {
+    throw std::invalid_argument("Error: zero_points.size() != width() ");
   }
 
-  int idx;
-  for (unsigned int b = 0; b < batch(); ++b) {
-    for (unsigned int c = 0; c < channel(); ++c) {
-      for (unsigned int h = 0; h < height(); ++h) {
-        for (unsigned int w = 0; w < width(); ++w) {
-          if (scale_idx == 0)
-            idx = b;
-          else if (scale_idx == 1)
-            idx = c;
-          else if (scale_idx == 2)
-            idx = h;
-          else if (scale_idx == 3)
-            idx = w;
+  zero_points = zp;
+}
 
-          if (output.getDataType() == Tdatatype::FP32) {
-            if (getDataType() == Tdatatype::QINT8) {
-              output.setValue(b, c, h, w,
-                              (float)getValue<int8_t>(b, c, h, w) *
-                                scale_factors[idx]);
-            } else {
-              output.setValue(b, c, h, w,
-                              (float)getValueQint4(b, c, h, w) *
-                                scale_factors[idx]);
-            }
-          } else if (output.getDataType() == Tdatatype::FP16) {
-#ifdef ENABLE_FP16
-            if (getDataType() == Tdatatype::QINT8) {
-              output.setValue(b, c, h, w,
-                              (_FP16)getValue<int8_t>(b, c, h, w) *
-                                (_FP16)scale_factors[idx]);
-            } else {
-              output.setValue(b, c, h, w,
-                              (_FP16)getValueQint4(b, c, h, w) *
-                                (_FP16)scale_factors[idx]);
-            }
-#else
-            throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-          }
-        }
-      }
-    }
-  }
+std::vector<uint8_t> Tensor::getZeroPoints() const { return zero_points; }
 
-  return output;
-} // namespace nntrainer
+// namespace nntrainer
 
 } /* namespace nntrainer */
index 991141797be53f94e39457a7c585d6d208e76d72..610e4c77955d6bfbe7a0cc424222144ddbe84d7a 100644 (file)
@@ -351,7 +351,7 @@ public:
    * @param[in] d data for the Tensor. It needs to set format properly.
    * @param[in] t_type Tensor type.
    */
-  Tensor(std::vector<std::vector<std::vector<std::vector<int8_t>>>> const &d,
+  Tensor(std::vector<std::vector<std::vector<std::vector<uint8_t>>>> const &d,
          ml::train::TensorDim::TensorType t_type) {
     if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
       throw std::out_of_range(
@@ -384,10 +384,10 @@ public:
 
     MemoryData *mem_data =
       (t_type.data_type == Tdatatype::QINT8)
-        ? new MemoryData((void *)(new int8_t[dim.getDataLen()]()))
-        : new MemoryData((void *)(new int8_t[(dim.getDataLen() + 1) / 2]()));
+        ? new MemoryData((void *)(new uint8_t[dim.getDataLen()]()))
+        : new MemoryData((void *)(new uint8_t[(dim.getDataLen() + 1) / 2]()));
     data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
-      delete[] mem_data->getAddr<int8_t>();
+      delete[] mem_data->getAddr<uint8_t>();
     });
     offset = 0;
     contiguous = true;
@@ -416,7 +416,7 @@ public:
    * @note      This constructor copies vector again. needs refactoring
    * @param[in] d data for the Tensor. It needs to set format properly.
    */
-  Tensor(std::vector<std::vector<std::vector<int8_t>>> const &d,
+  Tensor(std::vector<std::vector<std::vector<uint8_t>>> const &d,
          ml::train::TensorDim::TensorType t_type) :
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
 
@@ -425,7 +425,7 @@ public:
    * @note      This constructor copies vector again. needs refactoring
    * @param[in] d data for the Tensor with batch size one
    */
-  Tensor(std::vector<std::vector<int8_t>> const &d,
+  Tensor(std::vector<std::vector<uint8_t>> const &d,
          ml::train::TensorDim::TensorType t_type) :
     Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
 
@@ -575,8 +575,8 @@ public:
    * @param[in] idx location
    * @retval    qint4 value in location
    */
-  int8_t getValueQint4(unsigned int idx) const noexcept {
-    int8_t value = getData<int8_t>()[idx / 2];
+  uint8_t getValueQint4(unsigned int idx) const noexcept {
+    uint8_t value = getData<uint8_t>()[idx / 2];
     return decode_qint(value, (idx % 2 == 0));
   }
 
@@ -585,8 +585,8 @@ public:
    * @param[in] idx location
    * @retval    qint4 value in location
    */
-  int8_t getValueQint4(unsigned int idx) noexcept {
-    int8_t value = getData<int8_t>()[idx / 2];
+  uint8_t getValueQint4(unsigned int idx) noexcept {
+    uint8_t value = getData<uint8_t>()[idx / 2];
     return decode_qint(value, (idx % 2 == 0));
   }
 
@@ -598,10 +598,10 @@ public:
    * @param[in] w width location
    * @retval    qint4 value in location
    */
-  int8_t getValueQint4(unsigned int b, unsigned int c, unsigned int h,
-                       unsigned int w) const noexcept {
+  uint8_t getValueQint4(unsigned int b, unsigned int c, unsigned int h,
+                        unsigned int w) const noexcept {
     size_t idx = getIndex(b, c, h, w);
-    int8_t value = getData<int8_t>()[idx / 2];
+    uint8_t value = getData<uint8_t>()[idx / 2];
     return decode_qint(value, (idx % 2 == 0));
   }
 
@@ -613,10 +613,10 @@ public:
    * @param[in] w width location
    * @retval    qint4 value in location
    */
-  int8_t getValueQint4(unsigned int b, unsigned int c, unsigned int h,
-                       unsigned int w) noexcept {
+  uint8_t getValueQint4(unsigned int b, unsigned int c, unsigned int h,
+                        unsigned int w) noexcept {
     size_t idx = getIndex(b, c, h, w);
-    int8_t value = getData<int8_t>()[idx / 2];
+    uint8_t value = getData<uint8_t>()[idx / 2];
     return decode_qint(value, (idx % 2 == 0));
   }
 
@@ -1442,16 +1442,16 @@ public:
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
     } else if (getDataType() == Tdatatype::QINT8) {
-      getData<int8_t>()[getIndex(batch, c, h, w)] = value;
+      getData<uint8_t>()[getIndex(batch, c, h, w)] = value;
     } else if (getDataType() == Tdatatype::QINT4) {
       int idx = getIndex(batch, c, h, w);
 
       if (idx % 2 == 0) {
-        getData<int8_t>()[idx / 2] =
-          encode_qint(value, getData<int8_t>()[idx / 2]);
+        getData<uint8_t>()[idx / 2] =
+          encode_qint(value, getData<uint8_t>()[idx / 2]);
       } else {
-        getData<int8_t>()[idx / 2] =
-          encode_qint(getData<int8_t>()[idx / 2] >> 4, value);
+        getData<uint8_t>()[idx / 2] =
+          encode_qint(getData<uint8_t>()[idx / 2] >> 4, value);
       }
     }
   }
@@ -1479,8 +1479,8 @@ public:
       ml_loge("%s", "Error: enable-fp16 is not enabled");
 #endif
     } else if (getDataType() == Tdatatype::QINT8) {
-      getData<int8_t>()[idx] *= beta;
-      getData<int8_t>()[idx] += value;
+      getData<uint8_t>()[idx] *= beta;
+      getData<uint8_t>()[idx] += value;
     }
   }
 
@@ -1953,31 +1953,117 @@ public:
    */
   Tdatatype getDataType() const { return dim.getDataType(); }
 
+  /**
+   * @brief Set output axis of the tensor
+   * @param[in] axis output axis (0: batch, 1: channel, 2: height, 3: width)
+   */
+  void setOutputAxis(int axis);
+
+  /**
+   * @brief Get output axis of the tensor
+   *
+   * @return output axis of the tensor
+   */
+  int getOutputAxis() const;
+
   /**
    * @brief     Set scale factors of the tensor
    * @param[in] scales scale factors
    */
-  void setScaleFactors(std::vector<float> scales, int idx);
+  void setScaleFactors(std::vector<float> scales);
 
   /**
-   * @brief     Get scale factors of the tensor
-   * @retval    scales scale factors
+   * @brief Get scale factors of the tensor
+   *
+   * @return scale factors of the tensor
    */
-  std::vector<float> getScaleFactors();
+  std::vector<float> getScaleFactors() const;
 
   /**
-   * @brief     Dequantize Tensor to dtype
-   * @param[in] dtype Target Tensor DataType
+   * @brief     Set output axis of the tensor
+   * @param[in] zp zero points
+   */
+  void setZeroPoints(std::vector<uint8_t> zp);
+
+  /**
+   * @brief Get zero points of the tensor
+   *
+   * @return zero points of the tensor
+   */
+  std::vector<uint8_t> getZeroPoints() const;
+
+  /**
+   * @brief     Dequantize Tensor
    * @retval    Dequantized Tensor
    */
-  Tensor dequantize(Tdatatype dtype) const;
+  template <typename T = float> Tensor dequantize() const {
+    Tdatatype dtype =
+      (typeid(T) == typeid(float)) ? Tdatatype::FP32 : Tdatatype::FP16;
+
+    Tensor t =
+      Tensor(batch(), channel(), height(), width(), getFormat(), dtype);
+
+    return dequantize<T>(t);
+  }
 
   /**
    * @brief      Dequantize Tensor to output tensor datatype
    * @param[out] output Tensor to store the result
    * @retval     Dequantized Tensor
    */
-  Tensor dequantize(Tensor &output) const;
+  template <typename T> Tensor &dequantize(Tensor &output) const {
+    if (getDataType() == Tdatatype::FP32 || getDataType() == Tdatatype::FP16) {
+      throw std::invalid_argument("Error: Tensor cannot be dequantized");
+    }
+
+    if (output.getDataType() == Tdatatype::QINT8 ||
+        output.getDataType() == Tdatatype::QINT4) {
+      throw std::invalid_argument("Error: Target datatype is quantized type");
+    }
+
+    if (getFormat() != output.getFormat())
+      throw std::invalid_argument("Error: TensorType do not match");
+
+    if (batch() != output.batch() || channel() != output.channel() ||
+        width() != output.width() || height() != output.height())
+      throw std::invalid_argument("Error: TensorDim do not match");
+
+    if (scale_factors.empty()) {
+      throw std::invalid_argument("Error: No scale factors");
+    }
+
+    int idx;
+    for (unsigned int b = 0; b < batch(); ++b) {
+      for (unsigned int c = 0; c < channel(); ++c) {
+        for (unsigned int h = 0; h < height(); ++h) {
+          for (unsigned int w = 0; w < width(); ++w) {
+            if (output_axis == 0)
+              idx = b;
+            else if (output_axis == 1)
+              idx = c;
+            else if (output_axis == 2)
+              idx = h;
+            else if (output_axis == 3)
+              idx = w;
+
+            if (getDataType() == Tdatatype::QINT8) {
+              output.setValue(
+                b, c, h, w,
+                (T)(getValue<uint8_t>(b, c, h, w) - zero_points[idx]) *
+                  scale_factors[idx]);
+            } else {
+              output.setValue(
+                b, c, h, w,
+                (T)(getValueQint4(b, c, h, w) - zero_points[idx]) *
+                  scale_factors[idx]);
+            }
+          }
+        }
+      }
+    }
+
+    return output;
+  }
 
   static constexpr float epsilon = 1e-5;
 
@@ -1990,8 +2076,9 @@ private:
   std::string name; /**< name of the tensor */
   std::shared_ptr<MemoryData> data;
   size_t offset;
-  int scale_idx;
+  int output_axis;
   std::vector<float> scale_factors;
+  std::vector<uint8_t> zero_points;
 
   /**<
    * When using shared_data with tensor, this stores the ptr of the source
@@ -2133,14 +2220,14 @@ private:
    * @param[in]  low value for last 4 bits
    * @retval     Encoded value
    */
-  int8_t encode_qint(int8_t high, int8_t low) const;
+  uint8_t encode_qint(uint8_t high, uint8_t low) const;
 
   /**
    * @brief      Decode int8 value to a int4 value
    * @param[in]  idx index to retrieve value
    * @retval     Decoded value
    */
-  int8_t decode_qint(int8_t val, bool isHigh) const;
+  uint8_t decode_qint(uint8_t val, bool isHigh) const;
 
 }; // namespace nntrainer
 
index 9531860661ce8aabb953fcce330c742842e3374f..259637a6d97d0f59db5e29296275a69d022b1826 100644 (file)
@@ -648,9 +648,9 @@ void from_string(const std::string &value, std::vector<T> &property) {
 struct TensorDataTypeInfo {
   using Enum = nntrainer::TensorDim::DataType;
   static constexpr std::initializer_list<Enum> EnumList = {
-    Enum::QINT8, Enum::FP16, Enum::FP32};
+    Enum::QINT4, Enum::QINT8, Enum::FP16, Enum::FP32};
 
-  static constexpr const char *EnumStr[] = {"QINT8", "FP16", "FP32"};
+  static constexpr const char *EnumStr[] = {"QINT4", "QINT8", "FP16", "FP32"};
 };
 
 /**
index 71783a38f6be6b16b90fa2639d0659253e5647c6..5d4ae198e74e73f4d3206013d77dfd80f837d7bd 100644 (file)
@@ -193,12 +193,12 @@ TEST(nntrainer_Tensor, Tensor_04_p) {
   int batch = 3;
   int height = 3;
   int width = 10;
-  std::vector<std::vector<std::vector<int8_t>>> in;
+  std::vector<std::vector<std::vector<uint8_t>>> in;
 
   for (int k = 0; k < batch; ++k) {
-    std::vector<std::vector<int8_t>> ttv;
+    std::vector<std::vector<uint8_t>> ttv;
     for (int i = 0; i < height; ++i) {
-      std::vector<int8_t> tv;
+      std::vector<uint8_t> tv;
       for (int j = 0; j < width; ++j) {
         tv.push_back(k * height * width + i * width + j);
       }
@@ -209,30 +209,30 @@ TEST(nntrainer_Tensor, Tensor_04_p) {
 
   nntrainer::Tensor tensor = nntrainer::Tensor(
     in, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
-  ASSERT_NE(nullptr, tensor.getData<int8_t>());
+  ASSERT_NE(nullptr, tensor.getData<uint8_t>());
 
-  if (tensor.getValue<int8_t>(0, 0, 0, 1) != 1)
+  if (tensor.getValue<uint8_t>(0, 0, 0, 1) != 1)
     status = ML_ERROR_INVALID_PARAMETER;
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
 
 TEST(nntrainer_Tensor, Tensor_05_p) {
   int status = ML_ERROR_NONE;
-  std::vector<std::vector<std::vector<int8_t>>> in = {{{-8, -7}, {-6, -5}},
-                                                      {{-4, -3}, {-2, -1}},
-                                                      {{0, 1}, {2, 3}},
-                                                      {{4, 5}, {6, 7}}};
+  std::vector<std::vector<std::vector<uint8_t>>> in = {{{0, 1}, {2, 3}},
+                                                       {{4, 5}, {6, 7}},
+                                                       {{8, 9}, {10, 11}},
+                                                       {{12, 13}, {14, 15}}};
 
   nntrainer::Tensor tensor = nntrainer::Tensor(
     in, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT4});
-  ASSERT_NE(nullptr, tensor.getData<int8_t>());
+  ASSERT_NE(nullptr, tensor.getData<uint8_t>());
 
   for (size_t b = 0; b < tensor.batch(); ++b) {
     for (size_t c = 0; c < tensor.channel(); ++c) {
       for (size_t h = 0; h < tensor.height(); ++h) {
         for (size_t w = 0; w < tensor.width(); ++w) {
           size_t idx = tensor.getIndex(b, c, h, w);
-          ASSERT_EQ(idx - 8, tensor.getValueQint4(idx));
+          ASSERT_EQ(idx, tensor.getValueQint4(idx));
         }
       }
     }
@@ -243,16 +243,16 @@ TEST(nntrainer_Tensor, Tensor_06_p) {
   int status = ML_ERROR_NONE;
   nntrainer::Tensor tensor = nntrainer::Tensor(
     1, 4, 2, 2, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT4});
-  ASSERT_NE(nullptr, tensor.getData<int8_t>());
+  ASSERT_NE(nullptr, tensor.getData<uint8_t>());
 
-  tensor.setValue(-2);
+  tensor.setValue(2);
 
   for (size_t b = 0; b < tensor.batch(); ++b) {
     for (size_t c = 0; c < tensor.channel(); ++c) {
       for (size_t h = 0; h < tensor.height(); ++h) {
         for (size_t w = 0; w < tensor.width(); ++w) {
           size_t idx = tensor.getIndex(b, c, h, w);
-          ASSERT_EQ(-2, tensor.getValueQint4(idx));
+          ASSERT_EQ(2, tensor.getValueQint4(idx));
         }
       }
     }
@@ -4354,11 +4354,13 @@ TEST(nntrainer_Tensor, dequantize_01_n) {
 
   nntrainer::Tensor input(batch, channel, height, width);
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 4, 7});
 
   nntrainer::Tensor output(batch, channel, height, width);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
@@ -4374,11 +4376,13 @@ TEST(nntrainer_Tensor, dequantize_02_n) {
     batch + 1, channel, height + 1, width + 1,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 4, 7});
 
   nntrainer::Tensor output(batch, channel, height, width);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
@@ -4397,7 +4401,7 @@ TEST(nntrainer_Tensor, dequantize_03_n) {
 
   nntrainer::Tensor output(batch, channel, height, width);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
@@ -4413,13 +4417,15 @@ TEST(nntrainer_Tensor, dequantize_04_n) {
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
+  input.setOutputAxis(1);
   EXPECT_THROW(
     {
-      input.setScaleFactors({2.0, 1.5, 1.0, 0.5}, 1);
+      input.setScaleFactors({2.0, 1.5, 1.0, 0.5});
     },
     std::invalid_argument);
 
-  EXPECT_NO_THROW({ input.setScaleFactors({2.0, 1.5, 1.0, 0.5}, 2); });
+  input.setOutputAxis(2);
+  EXPECT_NO_THROW({ input.setScaleFactors({2.0, 1.5, 1.0, 0.5}); });
 }
 
 /**
@@ -4435,52 +4441,21 @@ TEST(nntrainer_Tensor, dequantize_05_n) {
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 4, 7});
 
   nntrainer::Tensor output(
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
-}
-
-/**
- * @brief dequantize qint8 tensor
- */
-TEST(nntrainer_Tensor, dequantize_06_p) {
-  int batch = 1;
-  int channel = 3;
-  int height = 4;
-  int width = 5;
-
-  nntrainer::Tensor input(
-    batch, channel, height, width,
-    {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
-  GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
-
-  nntrainer::Tensor output;
-
-  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP32); });
-
-  float answer_data[] = {
-    1.5, 1.5, 1.5, 1.5, 1.5, 3,   3,   3,   3,   3,   4.5, 4.5, 4.5, 4.5, 4.5,
-    6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,
-    8,   8,   8,   8,   8,   9,   9,   9,   9,   9,   5.5, 5.5, 5.5, 5.5, 5.5,
-    6,   6,   6,   6,   6,   6.5, 6.5, 6.5, 6.5, 6.5, 7,   7,   7,   7,   7};
-
-  nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
-                                                {nntrainer::Tformat::NCHW,
-                                                 nntrainer::Tdatatype::FP32}),
-                           answer_data);
-
-  EXPECT_EQ(output, answer);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
  * @brief dequantize tensor
  */
-TEST(nntrainer_Tensor, dequantize_07_p) {
+TEST(nntrainer_Tensor, dequantize_06_p) {
   size_t batch = 1;
   size_t channel = 3;
   size_t height = 4;
@@ -4492,12 +4467,14 @@ TEST(nntrainer_Tensor, dequantize_07_p) {
      height,
      width,
      {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8}},
-    true, nntrainer::Tensor::Initializer::ONES);
+    true, nntrainer::Tensor::Initializer::ZEROS);
   nntrainer::Tensor output(batch, channel, height, width);
 
   // Dequantize by channel
-  EXPECT_NO_THROW(input.setScaleFactors({-2, 2, 4}, 1));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(1));
+  EXPECT_NO_THROW(input.setScaleFactors({2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_1[] = {-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2,
                            -2, -2, -2, -2, -2, -2, -2, -2, 2,  2,  2,  2,
@@ -4513,8 +4490,10 @@ TEST(nntrainer_Tensor, dequantize_07_p) {
   EXPECT_EQ(output, answer1);
 
   // Dequantize by height
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4.8}, 2));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(2));
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4.8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_2[] = {
     -4.2, -4.2, -4.2, -4.2, -4.2, -2,   -2,   -2,   -2,   -2,   2,    2,
@@ -4530,8 +4509,10 @@ TEST(nntrainer_Tensor, dequantize_07_p) {
   EXPECT_EQ(output, answer2);
 
   // Dequantize by width
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4, -8}, 3));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(3));
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4, 8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_3[] = {
     -4.2, -2, 2, 4, -8, -4.2, -2, 2, 4, -8, -4.2, -2, 2, 4, -8,
@@ -4562,12 +4543,14 @@ TEST(nntrainer_Tensor, dequantize_08_p) {
      height,
      width,
      {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT4}},
-    true, nntrainer::Tensor::Initializer::ONES);
+    true, nntrainer::Tensor::Initializer::ZEROS);
   nntrainer::Tensor output(batch, channel, height, width);
 
   // Dequantize by channel
-  EXPECT_NO_THROW(input.setScaleFactors({-2, 2, 4}, 1));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(1));
+  EXPECT_NO_THROW(input.setScaleFactors({2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_1[] = {-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2,
                            -2, -2, -2, -2, -2, -2, -2, -2, 2,  2,  2,  2,
@@ -4583,8 +4566,10 @@ TEST(nntrainer_Tensor, dequantize_08_p) {
   EXPECT_EQ(output, answer1);
 
   // Dequantize by height
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4}, 2));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(2));
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_2[] = {-4.2, -4.2, -4.2, -4.2, -4.2, -2, -2, -2, -2, -2,
                            2,    2,    2,    2,    2,    4,  4,  4,  4,  4,
@@ -4600,8 +4585,10 @@ TEST(nntrainer_Tensor, dequantize_08_p) {
   EXPECT_EQ(output, answer2);
 
   // Dequantize by width
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4, -8}, 3));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW(input.setOutputAxis(3));
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4, 8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
   float answer_data_3[] = {
     -4.2, -2, 2, 4, -8, -4.2, -2, 2, 4, -8, -4.2, -2, 2, 4, -8,
index 67dc946399eed7d28bebc7a90785e04088fbb44f..529fcd9b0259015e9bba5466325397d24c7da1f1 100644 (file)
@@ -5859,13 +5859,15 @@ TEST(nntrainer_Tensor, dequantize_01_n) {
   nntrainer::Tensor input(batch, channel, height, width,
                           nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16);
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 4, 7});
 
   nntrainer::Tensor output(batch, channel, height, width,
                            nntrainer::Tformat::NCHW,
                            nntrainer::Tdatatype::FP16);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<_FP16>(output); }, std::invalid_argument);
 }
 
 /**
@@ -5881,13 +5883,16 @@ TEST(nntrainer_Tensor, dequantize_02_n) {
     batch + 1, channel, height + 1, width + 1,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 4, 7});
 
   nntrainer::Tensor output(batch, channel, height, width,
                            nntrainer::Tformat::NCHW,
                            nntrainer::Tdatatype::FP16);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<_FP16>(output); }, std::invalid_argument);
 }
 
 /**
@@ -5908,7 +5913,7 @@ TEST(nntrainer_Tensor, dequantize_03_n) {
                            nntrainer::Tformat::NCHW,
                            nntrainer::Tdatatype::FP16);
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<_FP16>(output); }, std::invalid_argument);
 }
 
 /**
@@ -5924,11 +5929,13 @@ TEST(nntrainer_Tensor, dequantize_04_p) {
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({0, 0, 0});
 
   nntrainer::Tensor output;
 
-  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP16); });
+  EXPECT_NO_THROW({ output = input.dequantize<_FP16>(); });
 
   _FP16 answer_data[] = {
     static_cast<_FP16>(1.5), static_cast<_FP16>(1.5), static_cast<_FP16>(1.5),
@@ -5975,14 +5982,16 @@ TEST(nntrainer_Tensor, dequantize_05_p) {
      height,
      width,
      {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8}},
-    true, nntrainer::Tensor::Initializer::ONES);
+    true, nntrainer::Tensor::Initializer::ZEROS);
   nntrainer::Tensor output(batch, channel, height, width,
                            nntrainer::Tformat::NCHW,
                            nntrainer::Tdatatype::FP16);
 
   // Dequantize by channel
-  EXPECT_NO_THROW(input.setScaleFactors({-2, 2, 4}, 1));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(1);
+  EXPECT_NO_THROW(input.setScaleFactors({2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_1[] = {-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2,
                            -2, -2, -2, -2, -2, -2, -2, -2, 2,  2,  2,  2,
@@ -5998,8 +6007,10 @@ TEST(nntrainer_Tensor, dequantize_05_p) {
   EXPECT_EQ(output, answer1);
 
   // Dequantize by height
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4.8}, 2));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(2);
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4.8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_2[] = {static_cast<_FP16>(-4.2), static_cast<_FP16>(-4.2),
                            static_cast<_FP16>(-4.2), static_cast<_FP16>(-4.2),
@@ -6039,8 +6050,10 @@ TEST(nntrainer_Tensor, dequantize_05_p) {
   EXPECT_EQ(output, answer2);
 
   // Dequantize by width
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4, -8}, 3));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(3);
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4, 8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_3[] = {static_cast<_FP16>(-4.2), static_cast<_FP16>(-2),
                            static_cast<_FP16>(2),    static_cast<_FP16>(4),
@@ -6096,14 +6109,16 @@ TEST(nntrainer_Tensor, dequantize_06_p) {
      height,
      width,
      {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT4}},
-    true, nntrainer::Tensor::Initializer::ONES);
+    true, nntrainer::Tensor::Initializer::ZEROS);
   nntrainer::Tensor output(batch, channel, height, width,
                            nntrainer::Tformat::NCHW,
                            nntrainer::Tdatatype::FP16);
 
   // Dequantize by channel
-  EXPECT_NO_THROW(input.setScaleFactors({-2, 2, 4}, 1));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(1);
+  EXPECT_NO_THROW(input.setScaleFactors({2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_1[] = {-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2,
                            -2, -2, -2, -2, -2, -2, -2, -2, 2,  2,  2,  2,
@@ -6119,8 +6134,10 @@ TEST(nntrainer_Tensor, dequantize_06_p) {
   EXPECT_EQ(output, answer1);
 
   // Dequantize by height
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4}, 2));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(2);
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_2[] = {static_cast<_FP16>(-4.2), static_cast<_FP16>(-4.2),
                            static_cast<_FP16>(-4.2), static_cast<_FP16>(-4.2),
@@ -6160,8 +6177,10 @@ TEST(nntrainer_Tensor, dequantize_06_p) {
   EXPECT_EQ(output, answer2);
 
   // Dequantize by width
-  EXPECT_NO_THROW(input.setScaleFactors({-4.2, -2, 2, 4, -8}, 3));
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  input.setOutputAxis(3);
+  EXPECT_NO_THROW(input.setScaleFactors({4.2, 2, -2, -4, 8}));
+  EXPECT_NO_THROW(input.setZeroPoints({1, 1, 1, 1, 1}));
+  EXPECT_NO_THROW({ input.dequantize<_FP16>(output); });
 
   _FP16 answer_data_3[] = {static_cast<_FP16>(-4.2), static_cast<_FP16>(-2),
                            static_cast<_FP16>(2),    static_cast<_FP16>(4),
index e05ba068e873338c60707a28234f3542e938cf48..7777cee0e8c106d05e34af89c73df881ffe50257 100644 (file)
@@ -4688,13 +4688,15 @@ TEST(nntrainer_Tensor, dequantize_01_n) {
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 0, 3});
 
   nntrainer::Tensor output(
     batch, channel, height, width,
     {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::FP32});
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
@@ -4710,13 +4712,15 @@ TEST(nntrainer_Tensor, dequantize_02_n) {
     batch, channel, height, width,
     {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({1, 0, 3});
 
   nntrainer::Tensor output(
     batch, channel, height, width,
     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32});
 
-  EXPECT_THROW({ input.dequantize(output); }, std::invalid_argument);
+  EXPECT_THROW({ input.dequantize<float>(output); }, std::invalid_argument);
 }
 
 /**
@@ -4732,18 +4736,20 @@ TEST(nntrainer_Tensor, dequantize_03_p) {
     batch, channel, height, width,
     {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(1);
+  input.setScaleFactors({1.5, 1.0, 0.5});
+  input.setZeroPoints({5, 5, 5});
 
   nntrainer::Tensor output;
   output.getDim().setFormat(nntrainer::Tformat::NHWC);
 
-  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP32); });
+  EXPECT_NO_THROW({ output = input.dequantize<float>(); });
 
-  float answer_data[] = {1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5,
-                         1.5, 6, 5.5, 3,   7, 6,   3,   7, 6,   3,   7, 6,
-                         3,   7, 6,   3,   7, 6,   4.5, 8, 6.5, 4.5, 8, 6.5,
-                         4.5, 8, 6.5, 4.5, 8, 6.5, 4.5, 8, 6.5, 6,   9, 7,
-                         6,   9, 7,   6,   9, 7,   6,   9, 7,   6,   9, 7};
+  float answer_data[] = {
+    -6,   1, 3,   -6,   1, 3,   -6,   1, 3,   -6,   1, 3,   -6,   1, 3,
+    -4.5, 2, 3.5, -4.5, 2, 3.5, -4.5, 2, 3.5, -4.5, 2, 3.5, -4.5, 2, 3.5,
+    -3,   3, 4,   -3,   3, 4,   -3,   3, 4,   -3,   3, 4,   -3,   3, 4,
+    -1.5, 4, 4.5, -1.5, 4, 4.5, -1.5, 4, 4.5, -1.5, 4, 4.5, -1.5, 4, 4.5};
 
   nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
                                                 {nntrainer::Tformat::NHWC,
@@ -4766,19 +4772,23 @@ TEST(nntrainer_Tensor, dequantize_04_p) {
     batch, channel, height, width,
     {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT8});
   GEN_TEST_INPUT(input, i * (batch * height) + j * (width) + k + 1);
-  input.setScaleFactors({1.5, 1.0, 0.5}, 1);
+  input.setOutputAxis(2);
+  input.setScaleFactors({2.5, 2.0, 1.5, 1.0});
+  input.setZeroPoints({8, 8, 8, 8});
 
   nntrainer::Tensor output(
     batch, channel, height, width,
     {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::FP32});
 
-  EXPECT_NO_THROW({ input.dequantize(output); });
+  EXPECT_NO_THROW({ input.dequantize<float>(output); });
 
-  float answer_data[] = {1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5, 1.5, 6, 5.5,
-                         1.5, 6, 5.5, 3,   7, 6,   3,   7, 6,   3,   7, 6,
-                         3,   7, 6,   3,   7, 6,   4.5, 8, 6.5, 4.5, 8, 6.5,
-                         4.5, 8, 6.5, 4.5, 8, 6.5, 4.5, 8, 6.5, 6,   9, 7,
-                         6,   9, 7,   6,   9, 7,   6,   9, 7,   6,   9, 7};
+  float answer_data[] = {
+    -17.5, -5, 7.5, -17.5, -5, 7.5, -17.5, -5, 7.5, -17.5, -5, 7.5,
+    -17.5, -5, 7.5, -12,   -2, 8,   -12,   -2, 8,   -12,   -2, 8,
+    -12,   -2, 8,   -12,   -2, 8,   -7.5,  0,  7.5, -7.5,  0,  7.5,
+    -7.5,  0,  7.5, -7.5,  0,  7.5, -7.5,  0,  7.5, -4,    1,  6,
+    -4,    1,  6,   -4,    1,  6,   -4,    1,  6,   -4,    1,  6,
+  };
 
   nntrainer::Tensor answer(ml::train::TensorDim(batch, channel, height, width,
                                                 {nntrainer::Tformat::NHWC,
@@ -4803,11 +4813,15 @@ TEST(nntrainer_Tensor, dequantize_05_p) {
      height,
      width,
      {nntrainer::Tformat::NHWC, nntrainer::Tdatatype::QINT4}},
-    true, nntrainer::Tensor::Initializer::ONES);
-  input.setScaleFactors({-8, -6, -4, -2, -1, 1, 2, 4, 6, 7}, 1);
+    true, nntrainer::Tensor::Initializer::ZEROS);
+
+  input.setOutputAxis(1);
+  input.setScaleFactors({8, 6, 4, 2, 1, -1, -2, -4, -6, -7});
+  input.setZeroPoints({1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+
   nntrainer::Tensor output;
 
-  EXPECT_NO_THROW({ output = input.dequantize(nntrainer::Tdatatype::FP32); });
+  EXPECT_NO_THROW({ output = input.dequantize<float>(); });
 
   float answer_data[] = {-8, -6, -4, -2, -1, 1, 2, 4, 6, 7,
                          -8, -6, -4, -2, -1, 1, 2, 4, 6, 7};
index ee755384ff1a366e39e54329089158f6f87a23c9..fa57141c0880c7e848c253222a5ceecd7216cb39 100644 (file)
@@ -488,10 +488,10 @@ TEST(TensorPool, validate_memory_reuse_01_p) {
 
   EXPECT_NO_THROW(pool.allocate());
 
-  EXPECT_EQ(t1->getAddress<float>(0), (float *)t2->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(1), (float *)t3->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(2), (float *)t4->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(3), (float *)t5->getAddress<int8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(0), (float *)t2->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(1), (float *)t3->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(2), (float *)t4->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(3), (float *)t5->getAddress<uint8_t>(0));
 
   EXPECT_NO_THROW(pool.deallocate());
 }
@@ -549,10 +549,10 @@ TEST(TensorPool, validate_memory_reuse_02_p) {
 
   EXPECT_NO_THROW(pool.allocate());
 
-  EXPECT_EQ(t1->getAddress<float>(0), (float *)t2->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(1), (float *)t3->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(2), (float *)t4->getAddress<int8_t>(0));
-  EXPECT_EQ(t1->getAddress<float>(3), (float *)t5->getAddress<int8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(0), (float *)t2->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(1), (float *)t3->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(2), (float *)t4->getAddress<uint8_t>(0));
+  EXPECT_EQ(t1->getAddress<float>(3), (float *)t5->getAddress<uint8_t>(0));
 
   EXPECT_NO_THROW(pool.deallocate());
 }