[Tensor] check data allocation in add/multiply_strided
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Thu, 20 Jul 2023 04:19:50 +0000 (13:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
nntrainer/tensor/tensor.cpp

index be8c681..0bf5290 100644 (file)
@@ -337,16 +337,25 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
     throw std::invalid_argument(
       "Strided multiplication does not support broadcasting");
 
+  if (getDataType() == Tdatatype::FP32) {
+    NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  } else if (getDataType() == Tdatatype::FP16) {
+    NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  }
+
   // Format NCHW Case
   if (this->getFormat() == Tformat::NCHW) {
     if (getDataType() == Tdatatype::FP32) {
-      NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
-        << getName() << " is not allocated";
-      NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
-        << m.getName() << " is not allocated";
-      NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
-        << output.getName() << " is not allocated";
-
       if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
           beta != 0.0) {
         for (unsigned int b = 0; b < batch(); ++b) {
@@ -377,13 +386,6 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output,
         }
       }
     } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-      NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
-        << getName() << " is not allocated";
-      NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
-        << m.getName() << " is not allocated";
-      NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
-        << output.getName() << " is not allocated";
-
       if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
           beta != 0.0) {
         for (unsigned int b = 0; b < batch(); ++b) {
@@ -503,6 +505,23 @@ Tensor &Tensor::add_strided(Tensor const &m, Tensor &output,
   if (size() != m.size() || size() != output.size())
     throw std::invalid_argument(
       "Strided addition does not support broadcasting");
+
+  if (getDataType() == Tdatatype::FP32) {
+    NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  } else if (getDataType() == Tdatatype::FP16) {
+    NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+      << getName() << " is not allocated";
+    NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+      << m.getName() << " is not allocated";
+    NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+      << output.getName() << " is not allocated";
+  }
+
   // Format NCHW Case
   if (this->getFormat() == Tformat::NCHW) {
     if (getDataType() == Tdatatype::FP32) {