From: Donghyeon Jeong
Date: Thu, 20 Jul 2023 04:19:50 +0000 (+0900)
Subject: [Tensor] check data allocation in add/multiply_strided
X-Git-Tag: accepted/tizen/8.0/unified/20231005.093407~91
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=27df22796e0802b749fa65c4f397a354bcd9e085;p=platform%2Fcore%2Fml%2Fnntrainer.git

[Tensor] check data allocation in add/multiply_strided

Signed-off-by: Donghyeon Jeong
---

diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index be8c681..0bf5290 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -337,16 +337,25 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
     throw std::invalid_argument(
       "Strided multiplication does not support broadcasting");
 
+  if (getDataType() == Tdatatype::FP32) {
+    NNTR_THROW_IF(getData() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  } else if (getDataType() == Tdatatype::FP16) {
+    NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  }
+
   // Format NCHW Case
   if (this->getFormat() == Tformat::NCHW) {
     if (getDataType() == Tdatatype::FP32) {
-      NNTR_THROW_IF(getData() == nullptr, std::invalid_argument)
-        << getName() << " is not allocated";
-      NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument)
-        << m.getName() << " is not allocated";
-      NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument)
-        << output.getName() << " is not allocated";
-
       if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
           beta != 0.0) {
         for (unsigned int b = 0; b < batch(); ++b) {
@@ -377,13 +386,6 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
         }
       }
     } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-      NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
-        << getName() << " is not allocated";
-      NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
-        << m.getName() << " is not allocated";
-      NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
-        << output.getName() << " is not allocated";
-
       if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
           beta != 0.0) {
         for (unsigned int b = 0; b < batch(); ++b) {
@@ -503,6 +505,23 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
   if (size() != m.size() || size() != output.size())
     throw std::invalid_argument(
       "Strided addition does not support broadcasting");
+
+  if (getDataType() == Tdatatype::FP32) {
+    NNTR_THROW_IF(getData() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  } else if (getDataType() == Tdatatype::FP16) {
+    NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  }
+
   // Format NCHW Case
   if (this->getFormat() == Tformat::NCHW) {
     if (getDataType() == Tdatatype::FP32) {
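
For illustration, the hoisted guard boils down to a simple pattern: validate every operand's backing memory once, before any format- or dtype-specific branch, so an unallocated tensor fails fast with the same error no matter which branch would have handled it. A minimal standalone C++ sketch of that pattern follows; the Buffer type and throw_if_unallocated helper are hypothetical stand-ins for this note only, not the nntrainer Tensor API or the NNTR_THROW_IF macro.

#include <stdexcept>
#include <string>

// Hypothetical stand-in for a tensor-like operand; nntrainer's Tensor
// exposes getData()/getName() instead.
struct Buffer {
  const float *data = nullptr;
  std::string name;
};

// Throws std::invalid_argument when the buffer has no backing memory,
// mirroring the "<name> is not allocated" message used in the patch.
static void throw_if_unallocated(const Buffer &b) {
  if (b.data == nullptr)
    throw std::invalid_argument(b.name + " is not allocated");
}

// Validate all operands up front, before any layout- or dtype-specific
// branch (the structure the patch moves the checks toward).
void add_strided_like(const Buffer &lhs, const Buffer &rhs, Buffer &out) {
  throw_if_unallocated(lhs);
  throw_if_unallocated(rhs);
  throw_if_unallocated(out);
  // ... element-wise work for the concrete layout/dtype would follow ...
}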