#include "TestHelper.h"
#include "Importer.h"
+#include "Knob.h"
+#include "IR/TFConv2D.h"
#include <loco.h>
#include <loco/IR/TensorShape.h>
// clang-format on
} // namespace
-TEST(TensorFlowImport, Conv2D_01)
+namespace
{
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
-
- signature.add_output(moco::tf::TensorName("conv2d", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(parse_graphdef(conv2d_01_pbtxtdata, graph_def));
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+void verify_Conv2D_01(loco::Graph *graph)
+{
// test 1.
// loco node : ConstGen - FeatureEncode -- Conv2D - FeatureDecode - Push
// ConstGen - FilterEncode /
}
}
+void verify_TFConv2D_01(loco::Graph *graph)
+{
+ // test 1.
+ // loco node : ConstGen - TFConv2D - Push
+ // ConstGen /
+ loco::Graph::NodeContext *loco_nodes = graph->nodes();
+ loco::Graph::InputContext *loco_inputs = graph->inputs();
+ ASSERT_EQ(loco_inputs->size(), 0);
+ ASSERT_EQ(loco_nodes->size(), 4);
+
+ int idx = 0;
+
+ loco::ConstGen *ifm = dynamic_cast<loco::ConstGen *>(loco_nodes->at(idx++));
+ loco::ConstGen *ker = dynamic_cast<loco::ConstGen *>(loco_nodes->at(idx++));
+ moco::tf::TFConv2D *tfconv2d = dynamic_cast<moco::tf::TFConv2D *>(loco_nodes->at(idx++));
+ loco::Push *push = dynamic_cast<loco::Push *>(loco_nodes->at(idx++));
+
+ ASSERT_NE(ifm, nullptr);
+ ASSERT_NE(ker, nullptr);
+ ASSERT_NE(tfconv2d, nullptr);
+ ASSERT_NE(push, nullptr);
+
+ // check their connection is all OK
+ ASSERT_TRUE(tfconv2d->ifm() == ifm);
+ ASSERT_TRUE(tfconv2d->ker() == ker);
+ ASSERT_TRUE(push->from() == tfconv2d);
+
+ // test 2.
+ // attrs inside TFConv2D
+ ASSERT_EQ(tfconv2d->padding(), "VALID");
+ ASSERT_EQ(tfconv2d->data_layout(), "NHWC");
+ auto strides = tfconv2d->strides();
+ ASSERT_EQ(strides.size(), 4);
+}
+
+} // namespace
+
+TEST(TensorFlowImport, Conv2D_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("conv2d", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(parse_graphdef(conv2d_01_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
+ verify_TFConv2D_01(graph.get());
+ else
+ verify_Conv2D_01(graph.get());
+}
+
namespace
{
// clang-format off
);
} // namespace
-TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name)
+namespace
{
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
-
- signature.add_input(moco::tf::TensorName("input", 0));
- signature.add_output(moco::tf::TensorName("InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(parse_graphdef(conv2d_inception_pbtxtdata, graph_def));
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+void verify_Conv2D_inception_indexed_tensor_name(loco::Graph *graph)
+{
// test 1.
// loco node : Pull - FeatureEncode -- Conv2D - FeatureDecode - Push
// ConstGen - FilterEncode /
ASSERT_EQ(tensor_shape.dim(3).value(), 3); // DEPTH
}
}
+
+void verify_TFConv2D_inception_indexed_tensor_name(loco::Graph *graph)
+{
+ // loco node : Pull - Conv2D - Push
+ // ConstGen /
+ loco::Graph::NodeContext *loco_nodes = graph->nodes();
+ loco::Graph::InputContext *loco_inputs = graph->inputs();
+ ASSERT_EQ(loco_inputs->size(), 1);
+ ASSERT_EQ(loco_nodes->size(), 4);
+
+ int idx = 0;
+
+ loco::Pull *ifm = dynamic_cast<loco::Pull *>(loco_nodes->at(idx++));
+ loco::ConstGen *ker = dynamic_cast<loco::ConstGen *>(loco_nodes->at(idx++));
+ moco::tf::TFConv2D *tfconv2d = dynamic_cast<moco::tf::TFConv2D *>(loco_nodes->at(idx++));
+ loco::Push *push = dynamic_cast<loco::Push *>(loco_nodes->at(idx++));
+
+ ASSERT_NE(ifm, nullptr);
+ ASSERT_NE(ker, nullptr);
+ ASSERT_NE(tfconv2d, nullptr);
+ ASSERT_NE(push, nullptr);
+
+ // check their connection is all OK
+ ASSERT_TRUE(tfconv2d->ifm() == ifm);
+ ASSERT_TRUE(tfconv2d->ker() == ker);
+ ASSERT_TRUE(push->from() == tfconv2d);
+}
+
+} // namespace
+
+TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("input", 0));
+ signature.add_output(moco::tf::TensorName("InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(parse_graphdef(conv2d_inception_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
+ verify_TFConv2D_inception_indexed_tensor_name(graph.get());
+ else
+ verify_Conv2D_inception_indexed_tensor_name(graph.get());
+}