throw std::invalid_argument(
"Strided multiplication does not support broadcasting");
+ if (getDataType() == Tdatatype::FP32) {
+ NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
+ << getName() << " is not allocated";
+ NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
+ << m.getName() << " is not allocated";
+ NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
+ << output.getName() << " is not allocated";
+ } else if (getDataType() == Tdatatype::FP16) {
+ NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+ << getName() << " is not allocated";
+ NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+ << m.getName() << " is not allocated";
+ NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+ << output.getName() << " is not allocated";
+ }
+
// Format NCHW Case
if (this->getFormat() == Tformat::NCHW) {
if (getDataType() == Tdatatype::FP32) {
- NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
- << getName() << " is not allocated";
- NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
- << m.getName() << " is not allocated";
- NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
- << output.getName() << " is not allocated";
-
if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
beta != 0.0) {
for (unsigned int b = 0; b < batch(); ++b) {
}
}
} else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
- NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
- << getName() << " is not allocated";
- NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
- << m.getName() << " is not allocated";
- NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
- << output.getName() << " is not allocated";
-
if (strides[3] != 1 || m.strides[3] != 1 || output.strides[3] != 1 ||
beta != 0.0) {
for (unsigned int b = 0; b < batch(); ++b) {
if (size() != m.size() || size() != output.size())
throw std::invalid_argument(
"Strided addition does not support broadcasting");
+
+ if (getDataType() == Tdatatype::FP32) {
+ NNTR_THROW_IF(getData<float>() == nullptr, std::invalid_argument)
+ << getName() << " is not allocated";
+ NNTR_THROW_IF(m.getData<float>() == nullptr, std::invalid_argument)
+ << m.getName() << " is not allocated";
+ NNTR_THROW_IF(output.getData<float>() == nullptr, std::invalid_argument)
+ << output.getName() << " is not allocated";
+ } else if (getDataType() == Tdatatype::FP16) {
+ NNTR_THROW_IF(getData<__fp16>() == nullptr, std::invalid_argument)
+ << getName() << " is not allocated";
+ NNTR_THROW_IF(m.getData<__fp16>() == nullptr, std::invalid_argument)
+ << m.getName() << " is not allocated";
+ NNTR_THROW_IF(output.getData<__fp16>() == nullptr, std::invalid_argument)
+ << output.getName() << " is not allocated";
+ }
+
// Format NCHW Case
if (this->getFormat() == Tformat::NCHW) {
if (getDataType() == Tdatatype::FP32) {