Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / OperationValidator.cc
index 5c545ae..4449631 100644 (file)
@@ -41,6 +41,21 @@ OperationValidator::OperationValidator(const ir::Graph &graph)
 {
 }
 
+void OperationValidator::checkUnaryOp(const ir::Operation &node)
+{
+  const auto output_index{node.getOutputs().at(0)};
+  const auto input_index{node.getInputs().at(0)};
+
+  // Check if I/O types match
+  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
+
+  if (_ctx.at(output_index).info().isDynamic())
+    return;
+
+  // Check if I/O shapes match
+  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+}
+
 void OperationValidator::operator()()
 {
   // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
@@ -53,16 +68,7 @@ void OperationValidator::operator()()
       [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
 }
 
-void OperationValidator::visit(const ir::operation::Abs &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-
-  const auto input_index{node.getInputs().at(0)};
-
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Abs &node) { checkUnaryOp(node); }
 
 void OperationValidator::visit(const ir::operation::AvgPool2D &node)
 {
@@ -292,17 +298,7 @@ void OperationValidator::visit(const ir::operation::RNN &node)
               num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
 }
 
-void OperationValidator::visit(const ir::operation::Round &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  const auto input_index{node.getInputs().at(ir::operation::Round::Input::INPUT)};
-
-  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Round &node) { checkUnaryOp(node); }
 
 void OperationValidator::visit(const ir::operation::SpaceToBatchND &node)
 {
@@ -393,17 +389,7 @@ void OperationValidator::visit(const ir::operation::EmbeddingLookup &node)
   }
 }
 
-void OperationValidator::visit(const ir::operation::Exp &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  const auto input_index{node.getInputs().at(ir::operation::Exp::Input::INPUT)};
-
-  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Exp &node) { checkUnaryOp(node); }
 
 void OperationValidator::visit(const ir::operation::ExpandDims &node)
 {
@@ -419,17 +405,7 @@ void OperationValidator::visit(const ir::operation::ExpandDims &node)
   OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
 }
 
-void OperationValidator::visit(const ir::operation::Floor &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  const auto input_index{node.getInputs().at(ir::operation::Floor::Input::INPUT)};
-
-  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Floor &node) { checkUnaryOp(node); }
 
 void OperationValidator::visit(const ir::operation::HashtableLookup &node)
 {
@@ -789,6 +765,25 @@ void OperationValidator::visit(const ir::operation::LSTM &node)
   }
 }
 
+void OperationValidator::visit(const ir::operation::L2Normalization &node)
+{
+  const auto ofm_index{node.getOutputs().at(0)};
+  if (_ctx.at(ofm_index).info().isDynamic())
+    return;
+
+  const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
+
+  auto ifm_shape = _ctx.at(ifm_index).shape();
+  auto ofm_shape = _ctx.at(ofm_index).shape();
+
+  OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
+
+  for (auto i = 0; i < ifm_shape.rank(); i++)
+  {
+    OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
+  }
+}
+
 void OperationValidator::visit(const ir::operation::Unpack &node)
 {
   const auto num{node.param().num};
@@ -904,45 +899,39 @@ void OperationValidator::visit(const ir::operation::Split &node)
   OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
 }
 
-void OperationValidator::visit(const ir::operation::Cos &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
+void OperationValidator::visit(const ir::operation::Cos &node) { checkUnaryOp(node); }
 
-  const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::Sin &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
+void OperationValidator::visit(const ir::operation::Sin &node) { checkUnaryOp(node); }
 
-  const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::RSQRT &node) { checkUnaryOp(node); }
 
-void OperationValidator::visit(const ir::operation::RSQRT &node)
+void OperationValidator::visit(const ir::operation::Shape &node)
 {
   const auto output_index{node.getOutputs().at(0)};
   if (_ctx.at(output_index).info().isDynamic())
     return;
 
   const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+  UNUSED_RELEASE(input_index);
+  OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
 }
 
-void OperationValidator::visit(const ir::operation::Shape &node)
+void OperationValidator::visit(const ir::operation::ResizeBilinear &node)
 {
   const auto output_index{node.getOutputs().at(0)};
+  const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
+
   if (_ctx.at(output_index).info().isDynamic())
+  {
     return;
+  }
+  OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
+  OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
 
-  const auto input_index{node.getInputs().at(0)};
-  UNUSED_RELEASE(input_index);
-  OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
+  auto align_corners = node.param().align_corners;
+  auto half_pixel_centers = node.param().half_pixel_centers;
+
+  OP_REQUIRES(!align_corners || !half_pixel_centers);
 }
 
 void OperationValidator::visit(const ir::operation::Reverse &node)
@@ -972,35 +961,11 @@ void OperationValidator::visit(const ir::operation::While &node)
   // TODO Add to validate with subgraphs
 }
 
-void OperationValidator::visit(const ir::operation::Neg &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
+void OperationValidator::visit(const ir::operation::Neg &node) { checkUnaryOp(node); }
 
-  const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Log &node) { checkUnaryOp(node); }
 
-void OperationValidator::visit(const ir::operation::Log &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-
-  const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::LogicalNot &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  if (_ctx.at(output_index).info().isDynamic())
-    return;
-
-  const auto input_index{node.getInputs().at(0)};
-  OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::LogicalNot &node) { checkUnaryOp(node); }
 
 void OperationValidator::visit(const ir::operation::SquaredDifference &node)
 {
@@ -1118,5 +1083,25 @@ void OperationValidator::visit(const ir::operation::LogSoftmax &node)
 
   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
 }
+
+void OperationValidator::visit(const ir::operation::Quantize &node)
+{
+  VERBOSE(Quantize) << "Configure Quantize operation" << std::endl;
+
+  OP_REQUIRES(node.getInputs().size() == 1);
+  OP_REQUIRES(node.getOutputs().size() == 1);
+
+  const auto input_index{node.getInputs().at(0)};
+  const auto output_index{node.getOutputs().at(0)};
+
+  OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32);
+
+  if (_ctx.at(output_index).info().isDynamic())
+    return;
+
+  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
+
+  OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
+}
 } // namespace compiler
 } // namespace onert