std::vector<TensorName> _names;
};
+void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs >= 2);
+ assert(num_inputs == _nodes.size());
+
+ loco::Node *target;
+ // do "%0.lhs : %in[0].name" connection
+ target = tensor_names->node(_names[0]);
+ _nodes[0]->lhs(target);
+
+ for (int i = 1; i < num_inputs; ++i)
+ {
+ // do "%i.rhs : %in[i].name" connections
+ target = tensor_names->node(_names[i]);
+ _nodes[i]->rhs(target);
+ }
+}
+
+void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs >= 2);
+ assert(num_inputs == _nodes.size());
+
+ loco::Node *target;
+ // do "%0.lhs : %in[0].name" connection
+ target = tensor_names->node(_names[0]);
+ _nodes[0]->lhs(target);
+
+ for (int i = 1; i < num_inputs; ++i)
+ {
+ // do "%i.rhs : %in[i].name" connections
+ target = tensor_names->node(_names[i]);
+ _nodes[i]->rhs(target);
+ }
+}
+
} // namespace
namespace moco
} // namespace tf
} // namespace moco
-// TODO move this block to upperside
-namespace
-{
-
-void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
-{
- int num_inputs = _names.size();
- assert(num_inputs >= 2);
- assert(num_inputs == _nodes.size());
-
- loco::Node *target;
- // do "%0.lhs : %in[0].name" connection
- target = tensor_names->node(_names[0]);
- _nodes[0]->lhs(target);
-
- for (int i = 1; i < num_inputs; ++i)
- {
- // do "%i.rhs : %in[i].name" connections
- target = tensor_names->node(_names[i]);
- _nodes[i]->rhs(target);
- }
-}
-
-void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
-{
- int num_inputs = _names.size();
- assert(num_inputs >= 2);
- assert(num_inputs == _nodes.size());
-
- loco::Node *target;
- // do "%0.lhs : %in[0].name" connection
- target = tensor_names->node(_names[0]);
- _nodes[0]->lhs(target);
-
- for (int i = 1; i < num_inputs; ++i)
- {
- // do "%i.rhs : %in[i].name" connections
- target = tensor_names->node(_names[i]);
- _nodes[i]->rhs(target);
- }
-}
-
-} // namespace
-
#include "GraphBuilderRegistry.h"
REGISTER_OP_BUILDER(ConcatV2, ConcatV2GraphBuilder)