[moco-tf] Refactor TF node graph builder test part 1 (#8296)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 18 Oct 2019 07:10:48 +0000 (16:10 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 18 Oct 2019 07:10:48 +0000 (16:10 +0900)
This will refactor test NOT to use Importer module of node graph builder from Add to FusedBatchNorm

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Op/Add.test.cpp
compiler/moco-tf/src/Op/AvgPool.test.cpp
compiler/moco-tf/src/Op/Concat.test.cpp
compiler/moco-tf/src/Op/Const.test.cpp
compiler/moco-tf/src/Op/Conv2D.test.cpp
compiler/moco-tf/src/Op/Conv2DBackpropInput.test.cpp
compiler/moco-tf/src/Op/DepthwiseConv2dNative.test.cpp
compiler/moco-tf/src/Op/FusedBatchNorm.test.cpp

index dc53f37..ca06120 100644 (file)
  * limitations under the License.
  */
 
+#include "Add.h"
 #include "TestHelper.h"
 
-#include "Importer.h"
-
-#include "IR/TFAdd.h"
-
-#include <loco.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <tensorflow/core/framework/graph.pb.h>
-
-#include <cstring>
-#include <memory>
-
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *add_basic_pbtxt = STRING_CONTENT(
-node {
-  name: "input_01"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-          dim {
-            size: 3
-          }
-          dim {
-            size: 4
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "input_02"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-          dim {
-            size: 3
-          }
-          dim {
-            size: 4
-          }
-        }
-        float_val: 2.0
-      }
-    }
-  }
-}
-node {
   name: "ADD_01"
   op: "Add"
   input: "input_01"
@@ -107,7 +35,6 @@ node {
       type: DT_FLOAT
     }
   }
-}
 );
 // clang-format on
 
@@ -115,22 +42,17 @@ node {
 
 TEST(TensorFlowImport, tf_add_basic)
 {
-  // load graph
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-  signature.add_output(moco::tf::TensorName("ADD_01", 0));
+  TFNodeBuildTester tester;
+  moco::tf::AddGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(add_basic_pbtxt, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  EXPECT_TRUE(plier::tf::parse_nodedef(add_basic_pbtxt, nodedef));
 
   // what to test:
   // - TFAdd node should exist
   // - both inputs x() and y() should not be null
 
-  auto add_node = moco::tf::test::find_first_node_bytype<moco::tf::TFAdd>(graph.get());
-
-  ASSERT_NE(add_node, nullptr);
-  ASSERT_NE(add_node->x(), nullptr);
-  ASSERT_NE(add_node->y(), nullptr);
+  tester.inputs({"input_01", "input_02"});
+  tester.output("ADD_01");
+  tester.run(nodedef, graphbuilder);
 }
index bf247f3..83fb0b1 100644 (file)
  */
 
 #include "AvgPool.h"
-
 #include "IR/TFAvgPool.h"
-
 #include "TestHelper.h"
 
-#include "Importer.h"
-
-#include <loco.h>
-#include <loco/IR/TensorShape.h>
-#include <loco/IR/FeatureShape.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <memory>
-
-using namespace moco::tf;
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *avgpool_01_pbtxtdata = STRING_CONTENT(
-node {
-  name: "const/float"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-          dim {
-            size: 3
-          }
-          dim {
-            size: 3
-          }
-          dim {
-            size: 1
-          }
-        }
-        float_val: 1.1
-      }
-    }
-  }
-}
-node {
   name: "avgpool"
   op: "AvgPool"
   input: "const/float"
@@ -115,7 +69,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -123,36 +76,25 @@ node {
 
 TEST(TensorFlowImport, AvgPool_01)
 {
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("avgpool", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(avgpool_01_pbtxtdata, graph_def));
-
-  {
-    // what to test:
-    // - there should exist TFAvgPool
-    // - attributes value should match
-    moco::tf::Importer importer;
-
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-    moco::tf::TFAvgPool *avgpool_node =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFAvgPool>(graph.get());
-    ASSERT_NE(avgpool_node, nullptr);
-
-    loco::Node *previous_node = avgpool_node->value();
-    auto following_nodes = loco::succs(avgpool_node);
-    ASSERT_EQ(following_nodes.size(), 1);
-    loco::Node *following_node = *following_nodes.begin();
-    ASSERT_NE(following_node, nullptr);
-
-    // attrs inside TFAvgPool2D
-    ASSERT_EQ(avgpool_node->data_layout(), "NHWC");
-    ASSERT_EQ(avgpool_node->padding(), "VALID");
-    ASSERT_EQ(avgpool_node->ksize(), std::vector<int64_t>({1, 2, 3, 1}));
-    ASSERT_EQ(avgpool_node->strides(), std::vector<int64_t>({1, 3, 2, 1}));
-  }
+  TFNodeBuildTester tester;
+  moco::tf::AvgPoolGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(avgpool_01_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFAvgPool
+  // - input should exist
+  // - attributes value should match
+
+  tester.inputs({"const/float"});
+  tester.output("avgpool");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFAvgPool *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->data_layout(), "NHWC");
+  ASSERT_EQ(test_node->padding(), "VALID");
+  ASSERT_EQ(test_node->ksize(), std::vector<int64_t>({1, 2, 3, 1}));
+  ASSERT_EQ(test_node->strides(), std::vector<int64_t>({1, 3, 2, 1}));
 }
index ec3adaf..fe6ade1 100644 (file)
  */
 
 #include "Concat.h"
-
 #include "IR/TFConcatV2.h"
-
 #include "TestHelper.h"
 
-#include "Importer.h"
-
-#include <loco.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-using namespace moco::tf;
 using namespace moco::tf::test;
 
 namespace
@@ -35,92 +27,6 @@ namespace
 
 // clang-format off
 const char *concat_01_pbtxtdata = STRING_CONTENT(
-node {
-  name: "Input01"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 1
-        float_val: 2
-        float_val: 3
-        float_val: 4
-        float_val: 5
-        float_val: 6
-      }
-    }
-  }
-}
-node {
-  name: "Input02"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 7
-        float_val: 8
-        float_val: 9
-        float_val: 10
-        float_val: 11
-        float_val: 12
-      }
-    }
-  }
-}
-node {
-  name: "Axis"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_INT32
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_INT32
-        tensor_shape {
-        }
-        int_val: 0
-      }
-    }
-  }
-}
-node {
   name: "Concat"
   op: "ConcatV2"
   input: "Input01"
@@ -144,7 +50,6 @@ node {
       type: DT_INT32
     }
   }
-}
 );
 // clang-format on
 
@@ -152,34 +57,25 @@ node {
 
 TEST(TensorFlowImport, concat_01)
 {
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  moco::tf::TensorName output("Concat", 0);
-  signature.add_output(output);
+  TFNodeBuildTester tester;
+  moco::tf::ConcatV2GraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(concat_01_pbtxtdata, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  EXPECT_TRUE(plier::tf::parse_nodedef(concat_01_pbtxtdata, nodedef));
 
-  {
-    // what to test:
-    // - there should exist TFConcatV2
-    // - there should be two values
-    // - values(idx) should not be nullptr
-    // - axis() should not be nullptr
-    moco::tf::Importer importer;
+  // what to test:
+  // - there should exist TFConcatV2
+  // - there should be two values
+  // - values(idx) should not be nullptr
+  // - axis() should not be nullptr
 
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  tester.inputs({"Input01", "Input02", "Axis"});
+  tester.output("Concat");
+  tester.run(nodedef, graphbuilder);
 
-    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
-
-    ASSERT_NE(concat_node, nullptr);
-    ASSERT_EQ(concat_node->num_values(), 2);
-    ASSERT_NE(concat_node->values(0), nullptr);
-    ASSERT_NE(concat_node->values(1), nullptr);
-    ASSERT_NE(concat_node->axis(), nullptr);
-  }
+  auto test_node = dynamic_cast<moco::tf::TFConcatV2 *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->num_values(), 2);
 }
 
 namespace
@@ -187,124 +83,6 @@ namespace
 
 // clang-format off
 const char *concat_02_pbtxtdata = STRING_CONTENT(
-node {
-  name: "Input01"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 1
-        float_val: 2
-        float_val: 3
-        float_val: 4
-        float_val: 5
-        float_val: 6
-      }
-    }
-  }
-}
-node {
-  name: "Input02"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 7
-        float_val: 8
-        float_val: 9
-        float_val: 10
-        float_val: 11
-        float_val: 12
-      }
-    }
-  }
-}
-node {
-  name: "Input03"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 13
-        float_val: 14
-        float_val: 15
-        float_val: 16
-        float_val: 17
-        float_val: 18
-      }
-    }
-  }
-}
-node {
-  name: "Axis"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_INT32
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_INT32
-        tensor_shape {
-        }
-        int_val: 0
-      }
-    }
-  }
-}
-node {
   name: "Concat"
   op: "ConcatV2"
   input: "Input01"
@@ -329,7 +107,6 @@ node {
       type: DT_INT32
     }
   }
-}
 );
 // clang-format on
 
@@ -337,31 +114,22 @@ node {
 
 TEST(TensorFlowImport, concat_02)
 {
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  moco::tf::TensorName output("Concat", 0);
-  signature.add_output(output);
+  TFNodeBuildTester tester;
+  moco::tf::ConcatV2GraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(concat_02_pbtxtdata, graph_def));
+  EXPECT_TRUE(plier::tf::parse_nodedef(concat_02_pbtxtdata, nodedef));
 
-  {
-    // what to test: TFConcatV2 has 3 inputs
-    // - there should exist TFConcatV2
-    // - values(idx) should not be nullptr
-    // - axis() should not be nullptr
-    moco::tf::Importer importer;
+  // what to test: TFConcatV2 has 3 inputs
+  // - there should exist TFConcatV2
+  // - values(idx) should not be nullptr
+  // - axis() should not be nullptr
 
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  tester.inputs({"Input01", "Input02", "Input03", "Axis"});
+  tester.output("Concat");
+  tester.run(nodedef, graphbuilder);
 
-    auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
-
-    ASSERT_NE(concat_node, nullptr);
-    ASSERT_EQ(concat_node->num_values(), 3);
-    ASSERT_NE(concat_node->values(0), nullptr);
-    ASSERT_NE(concat_node->values(1), nullptr);
-    ASSERT_NE(concat_node->values(2), nullptr);
-    ASSERT_NE(concat_node->axis(), nullptr);
-  }
+  auto test_node = dynamic_cast<moco::tf::TFConcatV2 *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->num_values(), 3);
 }
index 3d9d60c..b042be8 100644 (file)
 
 #include "Const.h"
 #include "TestHelper.h"
-
-#include "Importer.h"
-
 #include "IR/TFConst.h"
 
-#include <loco.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <cstring>
-#include <memory>
-
-using namespace moco::tf;
 using namespace moco::tf::test;
 
 namespace
@@ -39,7 +29,6 @@ namespace
 
 // clang-format off
 const char *const_float_01_pbtxtdata = STRING_CONTENT(
-node {
   name: "const/float"
   op: "Const"
   attr {
@@ -70,7 +59,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -78,30 +66,29 @@ node {
 
 TEST(TensorFlowImport, const_float_01)
 {
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("const/float", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(const_float_01_pbtxtdata, graph_def));
-
-  {
-    moco::tf::Importer importer;
-
-    auto graph = importer.import(signature, graph_def);
-
-    moco::tf::TFConst *node0 =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
-    ASSERT_NE(node0, nullptr);
-
-    ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
-  }
+  TFNodeBuildTester tester;
+  moco::tf::ConstGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(const_float_01_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFConst
+  // - values should match
+
+  tester.inputs({});
+  tester.output("const/float");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFConst *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(1), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(2), 3.3f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(3), 4.4f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(4), 5.5f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(5), 6.6f);
 }
 
 namespace
@@ -110,7 +97,6 @@ namespace
 
 // clang-format off
 const char *const_float_02_pbtxtdata = STRING_CONTENT(
-node {
   name: "const/float"
   op: "Const"
   attr {
@@ -136,7 +122,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -144,30 +129,29 @@ node {
 
 TEST(TensorFlowImport, const_float_02)
 {
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("const/float", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(const_float_02_pbtxtdata, graph_def));
-
-  {
-    moco::tf::Importer importer;
-
-    auto graph = importer.import(signature, graph_def);
-
-    moco::tf::TFConst *node0 =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
-    ASSERT_NE(node0, nullptr);
-
-    ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 1.1f);
-  }
+  TFNodeBuildTester tester;
+  moco::tf::ConstGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(const_float_02_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFConst
+  // - values should match
+
+  tester.inputs({});
+  tester.output("const/float");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFConst *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(1), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(2), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(3), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(4), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(5), 1.1f);
 }
 
 namespace
@@ -177,7 +161,6 @@ namespace
 
 // clang-format off
 const char *const_float_03_pbtxtdata = STRING_CONTENT(
-node {
   name: "const/float"
   op: "Const"
   attr {
@@ -203,7 +186,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -211,30 +193,29 @@ node {
 
 TEST(TensorFlowImport, const_float_03)
 {
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("const/float", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(const_float_03_pbtxtdata, graph_def));
-
-  {
-    moco::tf::Importer importer;
-
-    auto graph = importer.import(signature, graph_def);
-
-    moco::tf::TFConst *node0 =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
-    ASSERT_NE(node0, nullptr);
-
-    ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
-  }
+  TFNodeBuildTester tester;
+  moco::tf::ConstGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(const_float_03_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFConst
+  // - values should match
+
+  tester.inputs({});
+  tester.output("const/float");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFConst *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(1), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(2), 3.3f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(3), 4.4f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(4), 5.5f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(5), 6.6f);
 }
 
 namespace
@@ -243,7 +224,6 @@ namespace
 
 // clang-format off
 const char *const_float_04_pbtxtdata = STRING_CONTENT(
-node {
   name: "const/float"
   op: "Const"
   attr {
@@ -270,7 +250,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -278,29 +257,29 @@ node {
 
 TEST(TensorFlowImport, const_float_04)
 {
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("const/float", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(const_float_04_pbtxtdata, graph_def));
-
-  {
-    moco::tf::Importer importer;
-
-    auto graph = importer.import(signature, graph_def);
-
-    moco::tf::TFConst *node0 =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
-
-    ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 2.2f);
-    ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 2.2f);
-  }
+  TFNodeBuildTester tester;
+  moco::tf::ConstGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(const_float_04_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFConst
+  // - values should match
+
+  tester.inputs({});
+  tester.output("const/float");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFConst *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(1), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(2), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(3), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(4), 2.2f);
+  ASSERT_EQ(test_node->at<loco::DataType::FLOAT32>(5), 2.2f);
 }
 
 namespace
@@ -309,7 +288,6 @@ namespace
 
 // clang-format off
 const char *const_int32_04_pbtxtdata = STRING_CONTENT(
-node {
   name: "const/int"
   op: "Const"
   attr {
@@ -336,7 +314,6 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 
@@ -344,39 +321,27 @@ node {
 
 TEST(TensorFlowImport, const_int32_04)
 {
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("const/int", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(const_int32_04_pbtxtdata, graph_def));
-
-// TODO Re-enable this
-#if 0
-  loco::Graph::OutputContext *outputs = graph->outputs();
-  ASSERT_EQ(outputs->size(), 1);
-  loco::GraphOutput *output = outputs->at(0);
-  loco::Push *push = output->node();
-
-  loco::Graph::NodeContext *nodes = graph->nodes();
-  ASSERT_EQ(nodes->size(), 2);
-#endif
-
-  {
-    moco::tf::Importer importer;
-
-    auto graph = importer.import(signature, graph_def);
-
-    moco::tf::TFConst *node0 =
-        moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
-    ASSERT_NE(node0, nullptr);
-
-    ASSERT_EQ(node0->size<loco::DataType::S32>(), 6);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(0), 1);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(1), 2);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(2), 2);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(3), 2);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(4), 2);
-    ASSERT_EQ(node0->at<loco::DataType::S32>(5), 2);
-  }
+  TFNodeBuildTester tester;
+  moco::tf::ConstGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(const_int32_04_pbtxtdata, nodedef));
+
+  // what to test:
+  // - there should exist TFConst
+  // - values should match
+
+  tester.inputs({});
+  tester.output("const/int");
+  tester.run(nodedef, graphbuilder);
+
+  auto test_node = dynamic_cast<moco::tf::TFConst *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->size<loco::DataType::S32>(), 6);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(0), 1);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(1), 2);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(2), 2);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(3), 2);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(4), 2);
+  ASSERT_EQ(test_node->at<loco::DataType::S32>(5), 2);
 }
index aaccb69..8d6765f 100644 (file)
  */
 
 #include "Conv2D.h"
-
 #include "TestHelper.h"
-
-#include "Importer.h"
 #include "IR/TFConv2D.h"
 
-#include <loco.h>
-#include <loco/IR/TensorShape.h>
-#include <loco/IR/FeatureShape.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <memory>
-
-using namespace moco::tf;
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *conv2d_01_pbtxtdata = STRING_CONTENT(
-node {
-  name: "ifm"
-  op: "Const"
-  attr { key: "dtype" value { type: DT_FLOAT } }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim { size: 1 }
-          dim { size: 4 }
-          dim { size: 4 }
-          dim { size: 3 }
-        }
-        float_val: 1.1
-      }
-    }
-  }
-}
-node {
-  name: "ker"
-  op: "Const"
-  attr { key: "dtype" value { type: DT_FLOAT } }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim { size: 2 }
-          dim { size: 2 }
-          dim { size: 3 }
-          dim { size: 100 }
-        }
-        float_val: 1.1
-      }
-    }
-  }
-}
-node {
   name: "conv2d"
   op: "Conv2D"
   input: "ifm"
@@ -87,102 +35,41 @@ node {
   attr { key: "dilations" value { list { i: 1 i: 1 i: 1 i: 1 } } }
   attr { key: "padding" value { s: "VALID" } }
   attr { key: "strides" value { list { i: 1 i: 2 i: 3 i: 1 } } }
-  attr { key: "use_cudnn_on_gpu" value { b: false } }
-}
 );
 // clang-format on
 } // namespace
 
-namespace
+TEST(TensorFlowImport, Conv2D_01)
 {
+  TFNodeBuildTester tester;
+  moco::tf::Conv2DGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(conv2d_01_pbtxtdata, nodedef));
 
-void verify_TFConv2D_01(loco::Graph *graph)
-{
   // what to test:
-  // - Con2D node should exist
+  // - Conv2D node should exist
   // - ifm() should not be nullptr
   // - ker() should not be nullptr
   // - attribute values should match
 
-  // loco node : ConstGen - TFConv2D - Push
-  //             ConstGen /
-  moco::tf::TFConv2D *tfconv2d = moco::tf::test::find_first_node_bytype<moco::tf::TFConv2D>(graph);
-  ASSERT_NE(tfconv2d, nullptr);
-  ASSERT_NE(tfconv2d->input(), nullptr);
-  ASSERT_NE(tfconv2d->filter(), nullptr);
+  tester.inputs({"ifm", "ker"});
+  tester.output("conv2d");
+  tester.run(nodedef, graphbuilder);
 
-  // attrs inside TFConv2D
-  ASSERT_EQ(tfconv2d->padding(), "VALID");
-  ASSERT_EQ(tfconv2d->data_layout(), "NHWC");
-  auto strides = tfconv2d->strides();
+  auto test_node = dynamic_cast<moco::tf::TFConv2D *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->padding(), "VALID");
+  ASSERT_EQ(test_node->data_layout(), "NHWC");
+  auto strides = test_node->strides();
   ASSERT_EQ(strides.size(), 4);
   // TODO add verify dilation
 }
 
-} // namespace
-
-TEST(TensorFlowImport, Conv2D_01)
-{
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  signature.add_output(moco::tf::TensorName("conv2d", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(conv2d_01_pbtxtdata, graph_def));
-
-  {
-    moco::tf::Importer importer;
-
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-    verify_TFConv2D_01(graph.get());
-  }
-}
-
 namespace
 {
 // clang-format off
 const char *conv2d_inception_pbtxtdata = STRING_CONTENT(
-node {
-  name: "input"
-  op: "Placeholder"
-  attr {
-    key: "dtype" value { type: DT_FLOAT }
-  }
-  attr {
-    key: "shape"
-    value {
-      shape {
-        dim { size: 1 }
-        dim { size: 299 }
-        dim { size: 299 }
-        dim { size: 3 }
-      }
-    }
-  }
-}
-node {
-  name: "InceptionV3/Conv2d_1a_3x3/weights/read/_3__cf__3"
-  op: "Const"
-  attr { key: "dtype" value { type: DT_FLOAT } }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim { size: 3 }
-          dim { size: 3 }
-          dim { size: 3 }
-          dim { size: 32 }
-        }
-        float_val: 1.1
-      }
-    }
-  }
-}
-node {
   name: "InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D"
   op: "Conv2D"
   input: "input:0"
@@ -211,52 +98,23 @@ node {
       list { i: 1 i: 2 i: 2 i: 1 }
     }
   }
-  attr {
-    key: "use_cudnn_on_gpu"
-    value { b: true }
-  }
-}
 );
 } // namespace
 
-namespace
+TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name)
 {
+  TFNodeBuildTester tester;
+  moco::tf::Conv2DGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
+
+  EXPECT_TRUE(plier::tf::parse_nodedef(conv2d_inception_pbtxtdata, nodedef));
 
-void verify_TFConv2D_inception_indexed_tensor_name(loco::Graph *graph)
-{
   // what to test: name with ':0' should be treated correctly
-  // - Con2D node should exist
+  // - Conv2D node should exist
   // - ifm() should not be nullptr
   // - ker() should not be nullptr
 
-  // loco node : Pull - Conv2D - Push
-  //         ConstGen /
-  moco::tf::TFConv2D *tfconv2d =
-      moco::tf::test::find_first_node_bytype<moco::tf::TFConv2D>(graph);
-  ASSERT_NE(tfconv2d, nullptr);
-  ASSERT_NE(tfconv2d->input(), nullptr);
-  ASSERT_NE(tfconv2d->filter(), nullptr);
-}
-
-} // namespace
-
-TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name)
-{
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  signature.add_input(moco::tf::TensorName("input", 0));
-  signature.add_output(moco::tf::TensorName("InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(conv2d_inception_pbtxtdata, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-  {
-    moco::tf::Importer importer;
-
-    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-    verify_TFConv2D_inception_indexed_tensor_name(graph.get());
-  }
+  tester.inputs({"input", "InceptionV3/Conv2d_1a_3x3/weights/read/_3__cf__3"});
+  tester.output("InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D");
+  tester.run(nodedef, graphbuilder);
 }
index 988eb1b..087b7a9 100644 (file)
  * limitations under the License.
  */
 
+#include "Conv2DBackpropInput.h"
 #include "TestHelper.h"
-
-#include "Importer.h"
 #include "IR/TFConv2DBackpropInput.h"
 
-#include <loco/IR/TensorShape.h>
-#include <loco/IR/FeatureShape.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <memory>
-
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *conv2d_backprop_input_01_pbtxtdata = STRING_CONTENT(
-node {
-  name: "ifm"
-  op: "Placeholder"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "shape"
-    value {
-      shape {
-        dim {
-          size: 1
-        }
-        dim {
-          size: 8
-        }
-        dim {
-          size: 6
-        }
-        dim {
-          size: 3
-        }
-      }
-    }
-  }
-}
-node {
-  name: "weights"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 3
-          }
-          dim {
-            size: 3
-          }
-          dim {
-            size: 2
-          }
-          dim {
-            size: 3
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "outshape"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_INT32
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_INT32
-        tensor_shape {
-          dim {
-            size: 4
-          }
-        }
-        int_val: 1
-        int_val: 16
-        int_val: 12
-        int_val: 2
-      }
-    }
-  }
-}
-node {
   name: "ofm"
   op: "Conv2DBackpropInput"
   input: "outshape"
@@ -168,37 +71,29 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 } // namespace
 
 TEST(TensorFlowImport, conv2d_backprop_input_01)
 {
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  signature.add_input(moco::tf::TensorName("ifm", 0));
-  signature.add_output(moco::tf::TensorName("ofm", 0));
-
-  tensorflow::GraphDef graph_def;
-
-  EXPECT_TRUE(plier::tf::parse_graphdef(conv2d_backprop_input_01_pbtxtdata, graph_def));
+  TFNodeBuildTester tester;
+  moco::tf::Conv2DBackpropInputGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  EXPECT_TRUE(plier::tf::parse_nodedef(conv2d_backprop_input_01_pbtxtdata, nodedef));
 
   // what to test:
   // - All node inputs are valid
   // - All attributes are as expected
 
-  moco::tf::TFConv2DBackpropInput *tf_conv2d_backprop_input =
-      moco::tf::test::find_first_node_bytype<moco::tf::TFConv2DBackpropInput>(graph.get());
-  ASSERT_NE(tf_conv2d_backprop_input, nullptr);
-  ASSERT_NE(tf_conv2d_backprop_input->input_sizes(), nullptr);
-  ASSERT_NE(tf_conv2d_backprop_input->filter(), nullptr);
-  ASSERT_NE(tf_conv2d_backprop_input->out_backprop(), nullptr);
+  tester.inputs({"outshape", "weights", "ifm"});
+  tester.output("ofm");
+  tester.run(nodedef, graphbuilder);
 
-  ASSERT_EQ(tf_conv2d_backprop_input->padding(), "SAME");
-  ASSERT_EQ(tf_conv2d_backprop_input->data_layout(), "NHWC");
-  ASSERT_EQ(tf_conv2d_backprop_input->strides().size(), 4);
+  auto test_node = dynamic_cast<moco::tf::TFConv2DBackpropInput *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->padding(), "SAME");
+  ASSERT_EQ(test_node->data_layout(), "NHWC");
+  ASSERT_EQ(test_node->strides().size(), 4);
 }
index 64ae27d..f96be26 100644 (file)
  * limitations under the License.
  */
 
+#include "DepthwiseConv2dNative.h"
 #include "TestHelper.h"
-
-#include "Importer.h"
 #include "IR/TFDepthwiseConv2dNative.h"
 
-#include <loco/IR/TensorShape.h>
-#include <loco/IR/FeatureShape.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <memory>
-
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *depthwise_conv2d_native_01_pbtxtdata = STRING_CONTENT(
-node {
-  name: "input"
-  op: "Placeholder"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "shape"
-    value {
-      shape {
-        dim {
-          size: 1
-        }
-        dim {
-          size: 4
-        }
-        dim {
-          size: 4
-        }
-        dim {
-          size: 3
-        }
-      }
-    }
-  }
-}
-node {
-  name: "filter"
-  op: "Placeholder"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "shape"
-    value {
-      shape {
-        dim {
-          size: 2
-        }
-        dim {
-          size: 2
-        }
-        dim {
-          size: 3
-        }
-        dim {
-          size: 2
-        }
-      }
-    }
-  }
-}
-node {
-  name: "depthwise/Shape"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_INT32
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_INT32
-        tensor_shape {
-          dim {
-            size: 4
-          }
-        }
-        int_val: 2
-        int_val: 2
-        int_val: 3
-        int_val: 2
-      }
-    }
-  }
-}
-node {
-  name: "depthwise/dilation_rate"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_INT32
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_INT32
-        tensor_shape {
-          dim {
-            size: 2
-          }
-        }
-        int_val: 1
-        int_val: 1
-      }
-    }
-  }
-}
-node {
   name: "depthwise"
   op: "DepthwiseConv2dNative"
   input: "input"
@@ -188,32 +70,29 @@ node {
       }
     }
   }
-}
 );
 // clang-format on
 } // namespace
 
 TEST(TensorFlowImport, Depthwise_conv2d_native)
 {
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-
-  signature.add_input(moco::tf::TensorName("input", 0));
-  signature.add_output(moco::tf::TensorName("depthwise", 0));
-
-  tensorflow::GraphDef graph_def;
+  TFNodeBuildTester tester;
+  moco::tf::DepthwiseConv2dNativeGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  EXPECT_TRUE(plier::tf::parse_graphdef(depthwise_conv2d_native_01_pbtxtdata, graph_def));
+  EXPECT_TRUE(plier::tf::parse_nodedef(depthwise_conv2d_native_01_pbtxtdata, nodedef));
 
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  // what to test:
+  // - All node inputs are valid
+  // - All attributes are as expected
 
-  moco::tf::TFDepthwiseConv2dNative *tfdepthwiseconv2dnative =
-      moco::tf::test::find_first_node_bytype<moco::tf::TFDepthwiseConv2dNative>(graph.get());
-  ASSERT_NE(tfdepthwiseconv2dnative, nullptr);
-  ASSERT_NE(tfdepthwiseconv2dnative->input(), nullptr);
-  ASSERT_NE(tfdepthwiseconv2dnative->filter(), nullptr);
+  tester.inputs({"input", "filter"});
+  tester.output("depthwise");
+  tester.run(nodedef, graphbuilder);
 
-  ASSERT_EQ(tfdepthwiseconv2dnative->padding(), "VALID");
-  ASSERT_EQ(tfdepthwiseconv2dnative->data_layout(), "NHWC");
-  ASSERT_EQ(tfdepthwiseconv2dnative->strides().size(), 4);
+  auto test_node = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->padding(), "VALID");
+  ASSERT_EQ(test_node->data_layout(), "NHWC");
+  ASSERT_EQ(test_node->strides().size(), 4);
 }
index e5daf31..f6d486a 100644 (file)
  * limitations under the License.
  */
 
+#include "FusedBatchNorm.h"
 #include "TestHelper.h"
-
-#include "Importer.h"
-
 #include "IR/TFFusedBatchNorm.h"
 
-#include <loco.h>
-#include <plier/tf/TestHelper.h>
-
 #include <gtest/gtest.h>
 
-#include <cstring>
-#include <memory>
-
 using namespace moco::tf::test;
 
 namespace
 {
 // clang-format off
 const char *fbn_basic_pbtxt = STRING_CONTENT(
-node {
-  name: "input"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value { type: DT_FLOAT }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim { size: 1 }
-          dim { size: 4 }
-          dim { size: 4 }
-          dim { size: 1 }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "gamma"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "beta"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "FBN_01/mean"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
-  name: "FBN_01/variance"
-  op: "Const"
-  attr {
-    key: "dtype"
-    value {
-      type: DT_FLOAT
-    }
-  }
-  attr {
-    key: "value"
-    value {
-      tensor {
-        dtype: DT_FLOAT
-        tensor_shape {
-          dim {
-            size: 1
-          }
-        }
-        float_val: 1.0
-      }
-    }
-  }
-}
-node {
   name: "FBN_01"
   op: "FusedBatchNorm"
   input: "input"
@@ -185,7 +57,6 @@ node {
       b: false
     }
   }
-}
 );
 // clang-format on
 
@@ -193,14 +64,11 @@ node {
 
 TEST(TensorFlowImport, tf_fbn_basic)
 {
-  // load graph
-  moco::tf::Importer importer;
-  moco::tf::ModelSignature signature;
-  signature.add_output(moco::tf::TensorName("FBN_01", 0));
+  TFNodeBuildTester tester;
+  moco::tf::FusedBatchNormGraphBuilder graphbuilder;
+  tensorflow::NodeDef nodedef;
 
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(fbn_basic_pbtxt, graph_def));
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+  EXPECT_TRUE(plier::tf::parse_nodedef(fbn_basic_pbtxt, nodedef));
 
   // what to test:
   // - there should exist a TFFusedBatchNorm
@@ -211,13 +79,11 @@ TEST(TensorFlowImport, tf_fbn_basic)
   // - variance() should not be nullptr
   // - epsilon() value should match
 
-  moco::tf::TFFusedBatchNorm *fbn_node =
-      moco::tf::test::find_first_node_bytype<moco::tf::TFFusedBatchNorm>(graph.get());
+  tester.inputs({"input", "gamma", "beta", "FBN_01/mean", "FBN_01/variance"});
+  tester.output("FBN_01");
+  tester.run(nodedef, graphbuilder);
 
-  ASSERT_NE(fbn_node->x(), nullptr);
-  ASSERT_NE(fbn_node->scale(), nullptr);
-  ASSERT_NE(fbn_node->offset(), nullptr);
-  ASSERT_NE(fbn_node->mean(), nullptr);
-  ASSERT_NE(fbn_node->variance(), nullptr);
-  ASSERT_EQ(fbn_node->epsilon(), 0.001f);
+  auto test_node = dynamic_cast<moco::tf::TFFusedBatchNorm *>(tester.output());
+  ASSERT_NE(test_node, nullptr);
+  ASSERT_EQ(test_node->epsilon(), 0.001f);
 }