From b0de0aa23cb5ba0506c7d1ef3df79b24aea93cf1 Mon Sep 17 00:00:00 2001
From: skykongkong8
Date: Wed, 19 Jul 2023 09:52:28 +0900
Subject: [PATCH] [WIP] [Tensor] Add __fp16 to Tensor member functions

* add an if-else if code block to each Tensor member function
* fix trivially missed functions

Signed-off-by: skykongkong8
---
 nntrainer/tensor/tensor.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 7f8ff71..6e45c4d 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -384,9 +384,9 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
     for (unsigned int b = 0; b < batch(); ++b) {
       for (unsigned int c = 0; c < channel(); ++c) {
         for (unsigned int h = 0; h < height(); ++h) {
-          float *out_data = output.getAddress<__fp16>(b, c, h, 0);
-          const float *m_data = m.getAddress<__fp16>(b, c, h, 0);
-          const float *in_data = getAddress<__fp16>(b, c, h, 0);
+          __fp16 *out_data = output.getAddress<__fp16>(b, c, h, 0);
+          const __fp16 *m_data = m.getAddress<__fp16>(b, c, h, 0);
+          const __fp16 *in_data = getAddress<__fp16>(b, c, h, 0);
           std::transform(in_data, in_data + width(), m_data, out_data,
                          std::multiplies<__fp16>());
         }
@@ -834,7 +834,7 @@ Tensor &Tensor::add(float const &value, Tensor &out) const {
     return apply(f, out);
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
     auto f = std::bind(std::plus<__fp16>(), std::placeholders::_1, value);
-    return apply(f, out)
+    return apply(f, out);
   }
   return out;
 }
@@ -1181,10 +1181,10 @@ std::vector<Tensor> Tensor::split(std::vector<size_t> sizes, int axis) {
     return ret;
   }
   if (getDataType() == ml::train::TensorDim::DataType::FP16) {
-    auto iter_value = [this, is_format_nchw](
-                        std::array<size_t, 4> &loc,
-                        const std::array<size_t, 4> &end_loc,
-                        const std::array<size_t, 4> &reset_dim_arr) -> float & {
+    auto iter_value =
+      [this, is_format_nchw](
+        std::array<size_t, 4> &loc, const std::array<size_t, 4> &end_loc,
+        const std::array<size_t, 4> &reset_dim_arr) -> __fp16 & {
       auto &value = (is_format_nchw)
                       ? getValue<__fp16>(loc[0], loc[1], loc[2], loc[3])
                       : getValue<__fp16>(loc[0], loc[3], loc[1], loc[2]);
@@ -1567,8 +1567,8 @@ Tensor Tensor::sum_by_batch() const {
     sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
           ones.getData(), 1, 0.0, rdata, 1);
   } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
-    const __fp16 *data = getData();
-    __fp16 *rdata = ret.getData();
+    const __fp16 *data = getData<__fp16>();
+    __fp16 *rdata = ret.getData<__fp16>();
     Tensor ones(1, 1, 1, feat_len, this->getFormat());
     ones.setValue((__fp16)1.0);
-- 
2.7.4
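
Note for reviewers: the pattern this patch applies across the Tensor member functions is a runtime dispatch on the tensor's data type, where each arithmetic routine gains an `else if` branch that reruns the same logic with `__fp16` pointers and functors instead of `float`. Below is a minimal, self-contained sketch of that dispatch shape only; `MiniTensor`, `DataType`, and `add_scalar` are hypothetical stand-ins, not the real nntrainer Tensor API, and `__fp16` assumes a compiler/target that supports it (e.g. GCC/Clang on AArch64).

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Hypothetical stand-ins for TensorDim::DataType and Tensor internals;
// only the if / else-if dispatch pattern mirrors the patch, not the API.
enum class DataType { FP32, FP16 };

struct MiniTensor {
  DataType dtype;
  std::vector<float> f32;  // backing store when dtype == FP32
  std::vector<__fp16> f16; // backing store when dtype == FP16
};

// Element-wise scalar add, dispatched on the data type, in the same shape
// as the branches this patch adds to Tensor::add and friends.
void add_scalar(MiniTensor &t, float value) {
  if (t.dtype == DataType::FP32) {
    auto f = std::bind(std::plus<float>(), std::placeholders::_1, value);
    std::transform(t.f32.begin(), t.f32.end(), t.f32.begin(), f);
  } else if (t.dtype == DataType::FP16) {
    // Same logic, but with the __fp16 element type and functor specialization.
    auto f = std::bind(std::plus<__fp16>(), std::placeholders::_1,
                       static_cast<__fp16>(value));
    std::transform(t.f16.begin(), t.f16.end(), t.f16.begin(), f);
  }
}
```

The design choice mirrored here is duplicating the loop body per data type rather than templating the whole member function, which keeps the public Tensor interface unchanged while half-precision support is still being filled in.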