ASSERT_NE(concat_node->axis(), nullptr);
}
}
-
-namespace
-{
-
-// clang-format off
-const char *concat_03_pbtxtdata = STRING_CONTENT(
-node {
- name: "Input01"
- op: "Placeholder"
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- dim {
- size: 2
- }
- dim {
- size: 3
- }
- }
- }
- }
-}
-node {
- name: "Input02"
- op: "Placeholder"
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- dim {
- size: 2
- }
- dim {
- size: 3
- }
- }
- }
- }
-}
-node {
- name: "Axis"
- op: "Const"
- attr {
- key: "dtype"
- value {
- type: DT_INT32
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- }
- int_val: -1
- }
- }
- }
-}
-node {
- name: "Concat"
- op: "ConcatV2"
- input: "Input01"
- input: "Input02"
- input: "Axis"
- attr {
- key: "N"
- value {
- i: 2
- }
- }
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "Tidx"
- value {
- type: DT_INT32
- }
- }
-}
-);
-// clang-format on
-
-} // namespace
-
-TEST(TensorFlowImport, concat_03)
-{
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
-
- signature.add_input(moco::tf::TensorName("Input01", 0));
- signature.add_input(moco::tf::TensorName("Input02", 0));
- signature.add_output(moco::tf::TensorName("Concat", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(plier::tf::parse_graphdef(concat_03_pbtxtdata, graph_def));
- // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
- {
- // TODO fix indent
- // clang-format off
-
- // what to test: minus axis value validation
- // - there should exist a TensorConcat
- // - lhs() should not be nullptr
- // - rhs() should not be nullptr
- // - axis() should match 2 + (-1), where 2 came from rank of input(s)
-
- using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
-
- moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
- r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
- moco::tf::Importer importer{&r};
-
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
- loco::TensorConcat *concat_node =
- moco::tf::test::find_first_node_bytype<loco::TensorConcat>(graph.get());
-
- ASSERT_NE(concat_node, nullptr);
- ASSERT_NE(concat_node->lhs(), nullptr);
- ASSERT_NE(concat_node->rhs(), nullptr);
- ASSERT_EQ(concat_node->axis(), (2 + (-1)));
-
- // clang-format on
- }
-
- // Validation of axis value is skipped for TF
-}