This will fill GraphBuilder for TFConcatV2 and tests. Also ShapeInference implementation that is required by the test code.
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
std::vector<TensorName> _names;
};
+class TFConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+ TFConcatV2GraphUpdate(moco::tf::TFConcatV2 *node, std::vector<TensorName> names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFConcatV2 *_node;
+ std::vector<TensorName> _names;
+};
+
void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
{
int num_inputs = _names.size();
}
}
+void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ uint32_t num_values = _names.size() - 1; // exclude axis
+ assert(num_values >= 1);
+
+ for (uint32_t i = 0; i < num_values; ++i)
+ {
+ auto input_node = tensor_names->node(_names[i]);
+ assert(input_node != nullptr);
+ _node->values(i, input_node);
+ }
+ auto axis_node = tensor_names->node(_names[num_values]);
+ assert(axis_node != nullptr);
+ _node->axis(axis_node);
+}
+
} // namespace
namespace moco
void ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
GraphBuilderContext *context) const
{
- /**
- * @note This implementation may change when TFConcatV2 inherits from
- * something like VariableArity.
- */
-
loco::Graph *graph = context->graph();
NodeDefTable *nodedef = context->nodedef();
SymbolTable *tensor_names = context->tensor_names();
UpdateQueue *updates = context->updates();
- // TODO implement
- throw std::runtime_error("NYI ConcatV2GraphBuilderImpl");
+ const int num_inputs = node.input_size() - 1;
+ std::vector<TensorName> input_names;
+ auto concat_node = graph->nodes()->create<TFConcatV2>(num_inputs);
+
+ for (int ni = 0; ni < num_inputs; ++ni)
+ {
+ input_names.push_back(TensorName(node.input(ni)));
+ }
+ // last one is the axis
+ input_names.push_back(TensorName(node.input(num_inputs)));
+
+ // register string-name to the last node as output of concat(s)
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, concat_node);
+
+ auto update = stdex::make_unique<TFConcatV2GraphUpdate>(concat_node, input_names);
+ updates->enroll(std::move(update));
}
} // namespace tf
}
// Test "ConcatV2GraphBuilderImpl<ImportTarget::Tensorflow>"
- // TODO implement test
+ {
+ // what to test:
+ // - there should exist TFConcatV2
+ // - there should be two values
+ // - values(idx) should not be nullptr
+ // - axis() should not be nullptr
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+ ASSERT_NE(concat_node, nullptr);
+ ASSERT_EQ(concat_node->num_values(), 2);
+ ASSERT_NE(concat_node->values(0), nullptr);
+ ASSERT_NE(concat_node->values(1), nullptr);
+ ASSERT_NE(concat_node->axis(), nullptr);
+ }
}
namespace
}
// Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
- // TODO implement test
+ {
+ // what to test: TFConcatV2 has 3 inputs
+ // - there should exist TFConcatV2
+ // - values(idx) should not be nullptr
+ // - axis() should not be nullptr
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+ ASSERT_NE(concat_node, nullptr);
+ ASSERT_EQ(concat_node->num_values(), 3);
+ ASSERT_NE(concat_node->values(0), nullptr);
+ ASSERT_NE(concat_node->values(1), nullptr);
+ ASSERT_NE(concat_node->values(2), nullptr);
+ ASSERT_NE(concat_node->axis(), nullptr);
+ }
}
namespace
// clang-format on
}
- // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
- // TODO implement test
+ // Validation of axis value is skipped for TF
}
return copy_shapedata(value, node);
}
+template <class CONST_CLASS> bool valid_scala_value(CONST_CLASS *node)
+{
+ LOGGER(l);
+
+ auto shapedata = node->template annot<ShapeInferenceData>();
+ assert(shapedata != nullptr);
+
+ if (node->dtype() != loco::DataType::S32)
+ {
+ INFO(l) << "valid_scala_value not S32";
+ return false;
+ }
+
+ auto tensor_shape = shapedata->tensor_shape();
+ if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
+ {
+ INFO(l) << "valid_scala_value rank not 0/1 : " << tensor_shape.rank();
+ return false;
+ }
+
+ return true;
+}
+
+template <class CONST_CLASS> int32_t scala_value(CONST_CLASS *node)
+{
+ auto shapedata = node->template annot<ShapeInferenceData>();
+ assert(shapedata != nullptr);
+
+ assert(node->dtype() == loco::DataType::S32);
+
+ auto tensor_shape = shapedata->tensor_shape();
+ assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
+
+ return node->template at<loco::DataType::S32>(0);
+}
+
bool fix_shape(moco::tf::TFConcatV2 *node)
{
- (void)node;
+ LOGGER(l);
+
+ if (node->annot<ShapeInferenceData>() != nullptr)
+ {
+ // shape inference is already done for TFConcatV2
+ INFO(l) << "Fix shape TFConcatV2 already done";
+ return false;
+ }
+ // ConcatData should be null
+ assert(node->annot<ConcatData>() == nullptr);
+
+ // Check shape inference data are all ready
+ // Check shape rank are all same
+ auto value_a = node->values(0);
+ auto value_a_shapedata = value_a->annot<ShapeInferenceData>();
+ if (value_a_shapedata == nullptr)
+ {
+ // shape inference is not ready for this value
+ INFO(l) << "Fix shape TFConcatV2 value 0 shape_data not ready";
+ return false;
+ }
+ uint32_t a_rank = value_a_shapedata->rank();
- throw std::runtime_error("NYI fix_shape TFConcatV2");
+ uint32_t num_values = node->num_values();
+ for (uint32_t ni = 1; ni < num_values; ++ni)
+ {
+ auto value_b = node->values(ni);
+ auto value_b_shapedata = value_b->annot<ShapeInferenceData>();
+ if (value_b_shapedata == nullptr)
+ {
+ // shape inference is not ready for this value
+ INFO(l) << "Fix shape TFConcatV2 value " << ni << " shape_data not ready";
+ return false;
+ }
+ uint32_t b_rank = value_b_shapedata->rank();
+ assert(a_rank == b_rank);
+ }
- return false;
+ // check for axis
+ auto axis_node = node->axis();
+ auto axis_shapedata = axis_node->annot<ShapeInferenceData>();
+ if (axis_shapedata == nullptr)
+ {
+ // shape inference is not ready for axis_node
+ INFO(l) << "Fix shape TFConcatV2 axis shape_data not ready";
+ return false;
+ }
+
+ int32_t axis_value = 0;
+ bool axis_available = false;
+ {
+ // check for axis is TFConst
+ auto tfconst = dynamic_cast<moco::tf::TFConst *>(axis_node);
+ if (tfconst != nullptr)
+ {
+ if (valid_scala_value(tfconst))
+ {
+ axis_value = scala_value(tfconst);
+ axis_available = true;
+ }
+ }
+ }
+ {
+ // check for axis is ConstGen
+ auto constgen = dynamic_cast<loco::ConstGen *>(axis_node);
+ if (constgen != nullptr)
+ {
+ if (valid_scala_value(constgen))
+ {
+ axis_value = scala_value(constgen);
+ axis_available = true;
+ }
+ }
+ }
+ if (!axis_available)
+ {
+ // we cannot find a valid axis value
+ INFO(l) << "Fix shape TFConcatV2 axis_available false";
+ return false;
+ }
+
+ auto concat_data = stdex::make_unique<ConcatData>(axis_value);
+ node->annot(std::move(concat_data));
+
+ uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
+
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+ shape_data->rank(a_rank);
+
+ for (uint32_t index = 0; index < a_rank; ++index)
+ {
+ if (value_a_shapedata->dim(index).known())
+ {
+ uint32_t dim = value_a_shapedata->dim(index).value();
+ if (index == axis_absolute)
+ {
+ uint32_t dim_acc = dim;
+ for (uint32_t ni = 1; ni < num_values; ++ni)
+ {
+ auto value_b = node->values(ni);
+ auto value_b_shapedata = value_b->annot<ShapeInferenceData>();
+ assert(value_b_shapedata->dim(index).known());
+ dim_acc += value_b_shapedata->dim(index).value();
+ }
+ dim = dim_acc;
+ }
+ shape_data->dim(index) = dim;
+ }
+ else
+ shape_data->dim(index).unset();
+ }
+ node->annot(std::move(shape_data));
+
+ INFO(l) << "Fix TFConcat shape = " << node->annot<ShapeInferenceData>();
+
+ return true;
}
bool fix_shape(moco::tf::TFConst *node)