From 7db638f738c934c2c1e971c53ce5f8bdc7e345fb Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 1 Apr 2019 09:58:51 +0900 Subject: [PATCH] [neurun] Add validation of ComparisonNode (#4901) This commit adds validation of ComparisonNode. Signed-off-by: jiseob.jang --- runtimes/neurun/core/src/compiler/OperationValidator.cc | 14 ++++++++++++++ runtimes/neurun/core/src/compiler/OperationValidator.h | 1 + 2 files changed, 15 insertions(+) diff --git a/runtimes/neurun/core/src/compiler/OperationValidator.cc b/runtimes/neurun/core/src/compiler/OperationValidator.cc index da00bb1..279651d 100644 --- a/runtimes/neurun/core/src/compiler/OperationValidator.cc +++ b/runtimes/neurun/core/src/compiler/OperationValidator.cc @@ -40,6 +40,20 @@ void OperationValidator::visit(const model::operation::CastNode &node) assert(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); } +void OperationValidator::visit(const model::operation::ComparisonNode &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto lhs_index{node.getInputs().at(model::operation::ComparisonNode::Input::INPUT0)}; + const auto rhs_index{node.getInputs().at(model::operation::ComparisonNode::Input::INPUT1)}; + + UNUSED_RELEASE(output_index); + UNUSED_RELEASE(lhs_index); + UNUSED_RELEASE(rhs_index); + + assert(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); + assert(_ctx.at(output_index).typeInfo().type() == model::operand::DataType::TENSOR_BOOL8); +} + void OperationValidator::visit(const model::operation::SoftmaxNode &node) { VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl; diff --git a/runtimes/neurun/core/src/compiler/OperationValidator.h b/runtimes/neurun/core/src/compiler/OperationValidator.h index 346bcb1..7b9a3d3 100644 --- a/runtimes/neurun/core/src/compiler/OperationValidator.h +++ b/runtimes/neurun/core/src/compiler/OperationValidator.h @@ -42,6 +42,7 @@ public: public: virtual void visit(const model::operation::CastNode &node) override; + virtual void visit(const model::operation::ComparisonNode &node) override; virtual void visit(const model::operation::SoftmaxNode &node) override; virtual void visit(const model::operation::PermuteNode &node) override; virtual void visit(const model::operation::ReduceSumNode &node) override; -- 2.7.4