Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / TestGraph.h
index 7356204..2865b0f 100644 (file)
@@ -18,7 +18,6 @@
 #define __TEST_GRAPH_H__
 
 #include <luci/IR/CircleNodes.h>
-#include "GraphBlock.h"
 
 #include <loco.h>
 
@@ -36,29 +35,29 @@ class TestGraph
 {
 public:
   std::unique_ptr<loco::Graph> g;
-  loco::Pull *pull;
-  loco::Push *push;
+  luci::CircleInput *input_node = nullptr;
+  luci::CircleOutput *output_node = nullptr;
 
   TestGraph() // creates Pull and Push
   {
     g = loco::make_graph();
 
-    pull = g->nodes()->create<loco::Pull>();
+    input_node = g->nodes()->create<luci::CircleInput>();
 
-    push = g->nodes()->create<loco::Push>();
+    output_node = g->nodes()->create<luci::CircleOutput>();
 
     auto input = g->inputs()->create();
     {
       input->name("input");
-      loco::link(input, pull);
+      luci::link(input, input_node);
     }
     auto output = g->outputs()->create();
     {
       output->name("output");
-      loco::link(output, push);
+      luci::link(output, output_node);
     }
 
-    _next_input = pull;
+    _next_input = input_node;
   }
 
   loco::Graph *graph() { return g.get(); }
@@ -73,7 +72,7 @@ public:
   }
 
   /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
-  template <class T> T *append(loco::Node *arg1)
+  template <class T> T *append(luci::CircleNode *arg1)
   {
     auto node = g->nodes()->create<T>();
     setInput(node, arg1);
@@ -83,7 +82,7 @@ public:
   }
 
   /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
-  template <class T> T *append(loco::Node *arg1, loco::Node *arg2)
+  template <class T> T *append(luci::CircleNode *arg1, luci::CircleNode *arg2)
   {
     auto node = g->nodes()->create<T>();
     setInput(node, arg1, arg2);
@@ -93,7 +92,8 @@ public:
   }
 
   /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
-  template <class T> T *append(loco::Node *arg1, loco::Node *arg2, loco::Node *arg3)
+  template <class T>
+  T *append(luci::CircleNode *arg1, luci::CircleNode *arg2, luci::CircleNode *arg3)
   {
     auto node = g->nodes()->create<T>();
     setInput(node, arg1, arg2, arg3);
@@ -102,101 +102,68 @@ public:
     return node;
   }
 
-  // push will get the last appended node
-  void complete() { push->from(_next_input); }
+  // output will get the last appended node
+  void complete() { output_node->from(_next_input); }
 
-  void complete(loco::Node *last_node) { push->from(last_node); }
+  void complete(luci::CircleNode *last_node) { output_node->from(last_node); }
 
 private:
   // arity 1
-  void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); };
-
-  void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); }
-  void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); }
-  void setInput(loco::Push *node, loco::Node *input) { node->from(input); };
-  void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
-  void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); };
-
-  void setInput(luci::CircleAveragePool2D *node, loco::Node *input) { node->value(input); };
-  void setInput(luci::CircleMaxPool2D *node, loco::Node *input) { node->value(input); };
-  void setInput(luci::CircleRelu *node, loco::Node *input) { node->features(input); };
-  void setInput(luci::CircleRelu6 *node, loco::Node *input) { node->features(input); };
+  void setInput(luci::CircleNode *, luci::CircleNode *) { assert(false && "NYI"); };
 
-  // arity 2
-  void setInput(loco::Node *node, loco::Node *, loco::Node *) { assert(false && "NYI"); };
+  void setInput(luci::CircleAveragePool2D *node, luci::CircleNode *input) { node->value(input); };
+  void setInput(luci::CircleRelu *node, luci::CircleNode *input) { node->features(input); };
+  void setInput(luci::CircleSqueeze *node, luci::CircleNode *input) { node->input(input); };
 
-  void setInput(loco::Conv2D *node, loco::Node *input, loco::Node *filter)
+  void setInput(luci::CircleGatherNd *node, luci::CircleNode *params, luci::CircleNode *indices)
   {
-    node->ifm(input);
-    node->ker(filter);
-  }
-
-  void setInput(loco::EltwiseAdd *node, loco::Node *arg1, loco::Node *arg2)
-  {
-    node->lhs(arg1);
-    node->rhs(arg2);
+    node->params(params);
+    node->indices(indices);
   };
 
-  void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2)
+  // arity 2
+  void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
   {
-    node->value(arg1);
-    node->bias(arg2);
+    assert(false && "NYI");
   };
 
-  void setInput(luci::CircleAdd *node, loco::Node *arg1, loco::Node *arg2)
+  void setInput(luci::CircleExpandDims *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
   {
-    node->x(arg1);
-    node->y(arg2);
+    node->input(arg1);
+    node->axis(arg2);
   };
 
-  void setInput(luci::CircleMul *node, loco::Node *arg1, loco::Node *arg2)
+  void setInput(luci::CircleTranspose *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
   {
-    node->x(arg1);
-    node->y(arg2);
+    node->a(arg1);
+    node->perm(arg2);
   };
 
-  void setInput(luci::CircleSub *node, loco::Node *arg1, loco::Node *arg2)
+  void setInput(luci::CircleResizeBilinear *node, luci::CircleNode *input, luci::CircleNode *size)
   {
-    node->x(arg1);
-    node->y(arg2);
+    node->input(input);
+    node->size(size);
   };
 
-  void setInput(luci::CircleTranspose *node, loco::Node *arg1, loco::Node *arg2)
+  void setInput(luci::CircleResizeNearestNeighbor *node, luci::CircleNode *input,
+                luci::CircleNode *size)
   {
-    node->a(arg1);
-    node->perm(arg2);
+    node->input(input);
+    node->size(size);
   };
 
   // arity 3
-  void setInput(loco::Node *node, loco::Node *, loco::Node *, loco::Node *)
+  void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
   {
     assert(false && "NYI");
   };
 
-  void setInput(luci::CircleConv2D *node, loco::Node *input, loco::Node *filter, loco::Node *bias)
-  {
-    node->input(input);
-    node->filter(filter);
-    node->bias(bias);
-  }
-
 private:
   loco::Node *_next_input;
 };
 
 enum class ExampleGraphType
 {
-  FeatureBiasAdd,
-  ConstGen_ReLU,
-  FilterEncode_FilterDecode,
-  Transpose,
-
   CircleTranspose,
 };
 
@@ -205,109 +172,42 @@ template <ExampleGraphType T> class ExampleGraph;
 /**
  * @brief Class to create the following:
  *
- *   Pull - FeatureEncoder - FeatureBiasAdd - FeatureDecode - Push
- *                             |
- *     ConstGen - BiasEncode --+
+ *     CircleInput -- CircleTranspose -- CircleOutput
  */
-template <> class ExampleGraph<ExampleGraphType::FeatureBiasAdd> : public TestGraph
+template <> class ExampleGraph<ExampleGraphType::CircleTranspose> : public TestGraph
 {
 public:
-  loco::FeatureEncode *fea_enc = nullptr;
-  loco::ConstGen *constgen = nullptr;
-  loco::BiasEncode *bias_enc = nullptr;
-  loco::FeatureBiasAdd *fea_bias_add = nullptr;
-  loco::FeatureDecode *fea_dec = nullptr;
+  luci::CircleConst *const_perm = nullptr;
+  luci::CircleTranspose *transpose_node = nullptr;
 
 public:
   ExampleGraph()
   {
-    fea_enc = luci::make_feature_encode<luci::FeatureLayout::NHWC>(pull);
-    constgen = append<loco::ConstGen>();
-    bias_enc = append<loco::BiasEncode>(constgen);
-    fea_bias_add = append<loco::FeatureBiasAdd>(fea_enc, bias_enc);
-    fea_dec = luci::make_feature_decode<luci::FeatureLayout::NHWC>(fea_bias_add);
-    complete(fea_dec);
+    const_perm = append<luci::CircleConst>();
+    transpose_node = append<luci::CircleTranspose>(input_node, const_perm);
+    complete(transpose_node);
   }
 };
 
-/**
- * @brief Class to creates the following:
- *
- *     ConstGen -- ReLU -- Push
- */
-template <> class ExampleGraph<ExampleGraphType::ConstGen_ReLU> : public TestGraph
-{
-public:
-  loco::ConstGen *constgen = nullptr;
-  loco::ReLU *relu = nullptr;
-
-public:
-  ExampleGraph()
-  {
-    constgen = append<loco::ConstGen>();
-    relu = append<loco::ReLU>(constgen);
-    complete(relu);
-  }
-};
+} // namespace test
+} // namespace luci
 
-/**
- * @brief Class to creates the following:
- *
- *     Pull -- Transpose -- Push
- */
-template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+namespace luci
 {
-public:
-  loco::TensorTranspose *transpose = nullptr;
-
-public:
-  ExampleGraph()
-  {
-    transpose = append<loco::TensorTranspose>(pull);
-    complete(transpose);
-  }
-};
-
-/**
- * @brief Class to creates the following:
- *
- *     Pull -- FilterEncode -- FilterDecode -- Push
- */
-template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+namespace test
 {
-public:
-  loco::FilterEncode *filterEncode = nullptr;
-  loco::FilterDecode *filterDecode = nullptr;
 
-public:
-  ExampleGraph()
-  {
-    filterEncode = luci::make_filter_encode<luci::FilterLayout::HWIO>(pull); // from Tensorflow
-    filterDecode =
-        luci::make_filter_decode<luci::FilterLayout::OHWI>(filterEncode); // to Tensorflow Lite
-    complete(filterDecode);
-  }
-};
+/// @brief This will set GraphInput shape from CircleInput shape
+void graph_input_shape(luci::CircleInput *input);
 
-/**
- * @brief Class to create the following:
- *
- *     Pull -- CircleTranspose -- Push
- */
-template <> class ExampleGraph<ExampleGraphType::CircleTranspose> : public TestGraph
-{
-public:
-  loco::ConstGen *const_perm = nullptr;
-  luci::CircleTranspose *transpose_node = nullptr;
+/// @brief This will set GraphOutput shape from CircleOutput shape
+void graph_output_shape(luci::CircleOutput *output);
 
-public:
-  ExampleGraph()
-  {
-    const_perm = append<loco::ConstGen>();
-    transpose_node = append<luci::CircleTranspose>(pull, const_perm);
-    complete(transpose_node);
-  }
-};
+/// @brief This will set GraphInput dtype from CircleInput dtype
+void graph_input_dtype(luci::CircleInput *input);
+
+/// @brief This will set GraphOutput dtype from CircleOutput dtype
+void graph_output_dtype(luci::CircleOutput *output);
 
 } // namespace test
 } // namespace luci