#include "TestHelper.h"
+#include "IR/TFIdentity.h"
+#include "Op/Identity.h"
+
#include <loco.h>
#include <plier/tf/TestHelper.h>
TEST(TensorFlowImport, load_model_withio)
{
- moco::tf::Importer importer;
moco::tf::ModelSignature signature;
signature.add_input(moco::tf::TensorName("Placeholder", 0));
tensorflow::GraphDef graph_def;
EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+ using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
- loco::Graph::InputContext *inputs = graph->inputs();
- ASSERT_EQ(inputs->size(), 1);
- loco::GraphInput *input = inputs->at(0);
- ASSERT_EQ(input->dtype(), loco::DataType::FLOAT32);
- loco::Pull *pull = input->node();
+ // what to test:
+ // - import reads Pull
+ // - import reads Forward
+ // - attribute values should match
+
+ auto pull = find_first_node_bytype<loco::Pull>(graph.get());
+ ASSERT_NE(pull, nullptr);
+ auto forward = find_first_node_bytype<loco::Forward>(graph.get());
+ ASSERT_NE(forward, nullptr);
+
ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32);
ASSERT_EQ(pull->rank(), 4);
loco::Dimension dim1 = 1;
ASSERT_EQ(pull->dim(1).value(), dim2.value());
ASSERT_EQ(pull->dim(2).value(), dim1.value());
ASSERT_EQ(pull->dim(3).value(), dim2.value());
+}
+
+TEST(TensorFlowImport, load_model_withio_tf)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("output/identity", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+ using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ // TODO add Placeholder
+ r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - import reads Placeholder
+ // - import reads Identity
+ // - attribute values should match
- loco::Graph::OutputContext *outputs = graph->outputs();
- ASSERT_EQ(outputs->size(), 1);
- loco::GraphOutput *output = outputs->at(0);
- ASSERT_EQ(output->dtype(), loco::DataType::FLOAT32);
- loco::Push *push = output->node();
- // Currently we don't know the shape of output node(s)
- // ASSERT_EQ(push->rank(), 4);
- // ASSERT_EQ(push->dim(0).value(), dim1.value());
- // ASSERT_EQ(push->dim(1).value(), dim2.value());
- // ASSERT_EQ(push->dim(2).value(), dim1.value());
- // ASSERT_EQ(push->dim(3).value(), dim2.value());
-
- loco::Graph::NodeContext *nodes = graph->nodes();
- ASSERT_EQ(nodes->size(), 3);
- loco::Pull *node0 = dynamic_cast<loco::Pull *>(nodes->at(0));
- ASSERT_EQ(node0, pull);
- loco::Push *node2 = dynamic_cast<loco::Push *>(nodes->at(2));
- ASSERT_EQ(node2, push);
- loco::Forward *node1 = dynamic_cast<loco::Forward *>(nodes->at(1));
- ASSERT_NE(node1, nullptr);
+ auto tfidentity = find_first_node_bytype<moco::tf::TFIdentity>(graph.get());
+ ASSERT_NE(tfidentity, nullptr);
+ ASSERT_NE(tfidentity->input(), nullptr);
}