if (size() != from.size())
throw std::invalid_argument("Size of tensor to copy must match");
- if (getDataType() != from.getDataType())
- throw std::invalid_argument("Data type of tensor to copy must match");
-
- if (getDataType() == ml::train::TensorDim::DataType::FP32) {
- copy(from.getData<float>());
- } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
+ if (getDataType() == from.getDataType()) {
+ if (getDataType() == ml::train::TensorDim::DataType::FP32) {
+ copy(from.getData<float>());
+ } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
- copy(from.getData<_FP16>());
+ copy(from.getData<_FP16>());
#else
- throw std::invalid_argument("Error: enable-fp16 is not enabled");
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ } else {
+ copy(from.getData<uint8_t>());
+ }
+ } else {
+ if (getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (from.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ scopy(size(), from.getData<_FP16>(), 1, getData<float>(), 1);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
+ } else if (from.getDataType() == ml::train::TensorDim::DataType::QINT8) {
+ scopy_int8_to_float32(from.size(), from.getData<uint8_t>(), 1,
+ getData<float>(), 1);
+ } else if (from.getDataType() == ml::train::TensorDim::DataType::QINT4) {
+ scopy_int4_to_float32((from.size() + 1) / 2, from.getData<uint8_t>(), 1,
+ getData<float>(), 1);
+ }
+ } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ if (from.getDataType() == ml::train::TensorDim::DataType::QINT8) {
+ scopy_int8_to_float16(from.size(), from.getData<uint8_t>(), 1,
+ getData<_FP16>(), 1);
+ } else if (from.getDataType() == ml::train::TensorDim::DataType::QINT4) {
+ scopy_int4_to_float16((from.size() + 1) / 2, from.getData<uint8_t>(), 1,
+ getData<_FP16>(), 1);
+ }
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
}
}
<< " is too big. It cannot be represented by std::streamsize";
if (this->getDataType() == ml::train::TensorDim::DataType::FP32) {
-
- // std::vector<_FP16> temp(size());
- // for (unsigned int i = 0; i < size(); ++i) {
- // temp[i] = static_cast<_FP16>(getData()[i]);
- // }
-
- // checkedWrite(file, (char *)temp.data(),
- // static_cast<std::streamsize>(size() * sizeof(_FP16)),
- // "[Tensor::save] operation failed");
-
checkedWrite(file, (char *)getData(), sz,
"[Tensor::save] operation failed");
} else if (this->getDataType() == ml::train::TensorDim::DataType::FP16) {
std::vector<uint8_t> Tensor::getZeroPoints() const { return zero_points; }
-void Tensor::flate(Tensor &output) const {
- if (output.getDataType() == Tdatatype::FP32) {
- float *o_data = output.getData<float>();
- const uint8_t *data = getData<uint8_t>();
-
- if (getDataType() == Tdatatype::QINT4) {
- for (unsigned int i = 0; i < (output.getDim().getDataLen() + 1) / 2;
- ++i) {
- unsigned int idx = i * 2;
- o_data[idx] = data[i] >> 4;
- if (idx + 1 < output.getDim().getDataLen())
- o_data[idx + 1] = data[i] & 0x0f;
- }
- } else if (getDataType() == Tdatatype::QINT8) {
- for (unsigned int i = 0; i < output.getDim().getDataLen(); ++i) {
- o_data[i] = data[i];
- }
- }
- } else if (output.getDataType() == Tdatatype::FP16) {
-#ifdef ENABLE_FP16
- _FP16 *o_data = output.getData<_FP16>();
- const uint8_t *data = getData<uint8_t>();
-
- if (getDataType() == Tdatatype::QINT8) {
- for (unsigned int i = 0; i < output.getDim().getDataLen(); ++i) {
- o_data[i] = data[i];
- }
- }
-#else
- throw std::invalid_argument("enble-fp16 is not set");
-#endif
- }
-}
-
void Tensor::dequantize(Tensor &output, unsigned int axis) const {
if (getDataType() == Tdatatype::FP32 || getDataType() == Tdatatype::FP16) {
throw std::invalid_argument("Error: Tensor cannot be dequantized");
size_t h = (axis == 2) ? zero_points.size() : 1;
size_t w = (axis == 3) ? zero_points.size() : 1;
+ output.copyData(*this);
+
if (output.getDataType() == Tdatatype::FP16) {
#ifdef ENABLE_FP16
- if (getDataType() == Tdatatype::QINT4) {
- scopy_int4_to_float16((size() + 1) / 2, getData<uint8_t>(), 1,
- output.getData<_FP16>(), 1);
- } else if (getDataType() == Tdatatype::QINT8) {
- // @todo scopy for qint8
- flate(output);
- }
-
std::vector<_FP16> zero_points_16(zero_points.begin(), zero_points.end());
Tensor zero_points_fp16_tensor(
{{b, c, h, w}, {getFormat(), Tdatatype::FP16}}, zero_points_16.data());
throw std::invalid_argument("enble-fp16 is not set");
#endif
} else if (output.getDataType() == Tdatatype::FP32) {
- // @todo need scopy for uint8 to float
- flate(output);
-
std::vector<float> zero_points_32(zero_points.begin(), zero_points.end());
Tensor zero_points_fp32_tensor(
{{b, c, h, w}, {getFormat(), Tdatatype::FP32}}, zero_points_32.data());