From 9c61acda445d1641bd4d468d35401dd16c23ee1a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 20 Jun 2019 07:54:20 +0900 Subject: [PATCH] [moco/tf] arrange concat validate codes (#3881) This will move(arrange) codes in build to validate for ConcatV2 Signed-off-by: SaeHie Park --- contrib/moco/lib/frontend/tf/src/Op/Concat.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp b/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp index 69d20e0..e787f71 100644 --- a/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp +++ b/contrib/moco/lib/frontend/tf/src/Op/Concat.cpp @@ -60,7 +60,14 @@ private: std::vector _names; }; -bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const { return true; } +bool ConcatV2GraphBuilder::validate(const tensorflow::NodeDef &node) const +{ + // Concat node SHOULD have 3 or more inputs, that is 2 + axis + const int num_inputs = node.input_size() - 1; + assert(num_inputs >= 2); + assert(num_inputs == get_int_attr(node, "N")); + return (num_inputs >= 2) && (num_inputs == get_int_attr(node, "N")); +} void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const @@ -91,11 +98,7 @@ void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node, // %(N-1).lhs = %(N-2) // %N.lhs = %(N-1) - // Queue node input update - // Concat node SHOULD have 3 or more inputs, that is 2 + axis const int num_inputs = node.input_size() - 1; - assert(num_inputs >= 2); - assert(num_inputs == get_int_attr(node, "N")); std::vector concat_nodes; std::vector input_names; @@ -103,6 +106,7 @@ void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node, auto concat_node = graph->nodes()->create(); loco::TensorConcat *last_concat = concat_node; + // Queue node input update concat_nodes.push_back(concat_node); // used for LHS of connection -> %0 concat_nodes.push_back(concat_node); // used for RHS of connection -> %1 input_names.push_back(TensorName(node.input(0))); // for first concat (%0) LHS -- 2.7.4