Make Compatible Shapes for CPU Add (#6540)
author이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Fri, 20 Sep 2019 04:04:58 +0000 (13:04 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 20 Sep 2019 04:04:58 +0000 (13:04 +0900)
- This commit addes routine for making compatible shapes for CPU Add
  kernel

Signed-off-by: Chunseok Lee <chunseok.lee@samsung.com>
runtimes/neurun/backend/cpu/ShapeFixer.cc

index 741f07d..568666f 100644 (file)
@@ -77,12 +77,19 @@ void ShapeFixer::visit(const model::operation::AddNode &node)
   const auto lhs_index{node.getInputs().at(model::operation::AddNode::Input::LHS)};
   const auto rhs_index{node.getInputs().at(model::operation::AddNode::Input::RHS)};
 
-  // Broadcasting and quantization
-  if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()) ||
-      _ctx.at(lhs_index).typeInfo().type() == model::DataType::QUANT8_ASYMM)
+  // Quantization : not supported
+  if (_ctx.at(lhs_index).typeInfo().type() == model::DataType::QUANT8_ASYMM)
   {
     throw std::runtime_error{"NYI"};
   }
+  // Broadcast
+  if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
+  {
+    const auto broadcast_rank =
+        std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
+    const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
+    const_cast<::neurun::model::Shape &>(_ctx.at(rhs_index).shape()).extendRank(broadcast_rank);
+  }
 }
 
 void ShapeFixer::visit(const model::operation::PermuteNode &) { /* DO NOTHING */}