From 27df22796e0802b749fa65c4f397a354bcd9e085 Mon Sep 17 00:00:00 2001 From: Donghyeon Jeong Date: Thu, 20 Jul 2023 13:19:50 +0900 Subject: [PATCH] [Tensor] check data allocation in add/multiply_strided Signed-off-by: Donghyeon Jeong --- nntrainer/tensor/tensor.cpp | 47 +++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index be8c681..0bf5290 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -337,16 +337,25 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output, throw std::invalid_argument( "Strided multiplication does not support broadcasting"); + if (getDataType() == Tdatatype::FP32) { + NNTR_THROW_IF(getData() == nullptr, std::invalid_argument) + << getName() << " is not allocated"; + NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument) + << m.getName() << " is not allocated"; + NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument) + << output.getName() << " is not allocated"; + } else if (getDataType() == Tdatatype::FP16) { + NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument) + << getName() << " is not allocated"; + NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument) + << m.getName() << " is not allocated"; + NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument) + << output.getName() << " is not allocated"; + } + // Format NCHW Case if (this->getFormat() == Tformat::NCHW) { if (getDataType() == Tdatatype::FP32) { - NNTR_THROW_IF(getData() == nullptr, std::invalid_argument) - << getName() << " is not allocated"; - NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument) - << m.getName() << " is not allocated"; - NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument) - << output.getName() << " is not allocated"; - if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 || beta != 0.0) { for (unsigned int b = 0; b < batch(); ++b) { @@ -377,13 +386,6 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output, } } } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) { - NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument) - << getName() << " is not allocated"; - NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument) - << m.getName() << " is not allocated"; - NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument) - << output.getName() << " is not allocated"; - if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 || beta != 0.0) { for (unsigned int b = 0; b < batch(); ++b) { @@ -503,6 +505,23 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output, if (size() != m.size() || size() != output.size()) throw std::invalid_argument( "Strided addition does not support broadcasting"); + + if (getDataType() == Tdatatype::FP32) { + NNTR_THROW_IF(getData() == nullptr, std::invalid_argument) + << getName() << " is not allocated"; + NNTR_THROW_IF(m.getData() == nullptr, std::invalid_argument) + << m.getName() << " is not allocated"; + NNTR_THROW_IF(output.getData() == nullptr, std::invalid_argument) + << output.getName() << " is not allocated"; + } else if (getDataType() == Tdatatype::FP16) { + NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument) + << getName() << " is not allocated"; + NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument) + << m.getName() << " is not allocated"; + NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument) + << output.getName() << " is not allocated"; + } + // Format NCHW Case if (this->getFormat() == Tformat::NCHW) { if (getDataType() == Tdatatype::FP32) { -- 2.7.4