Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / loader / KernelBuilder.h
index 7e30d39..d5c5a4b 100644 (file)
 
 #include <memory>
 #include <vector>
+#include <unordered_map>
 
 namespace luci_interpreter
 {
 
-class GraphLoader;
-class ModuleLoader;
-
 class KernelBuilder : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>
 {
 public:
-  KernelBuilder(const ModuleLoader &module_loader, const GraphLoader &graph_loader)
-      : _module_loader(module_loader), _graph_loader(graph_loader)
+  KernelBuilder(
+      const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+      const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+      : _graph_to_runtime_graph(graph_to_runtime_graph), _node_to_tensor(node_to_tensor)
   {
   }
 
@@ -45,6 +45,7 @@ public:
   std::unique_ptr<Kernel> visit(const luci::CircleConcatenation *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleConv2D *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleConst *node) override;
+  std::unique_ptr<Kernel> visit(const luci::CircleDepthToSpace *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleDepthwiseConv2D *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleElu *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleFullyConnected *node) override;
@@ -61,6 +62,8 @@ public:
   std::unique_ptr<Kernel> visit(const luci::CircleOutput *node) override;
   std::unique_ptr<Kernel> visit(const luci::CirclePad *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleReshape *node) override;
+  std::unique_ptr<Kernel> visit(const luci::CircleReverseV2 *node) override;
+  std::unique_ptr<Kernel> visit(const luci::CircleSlice *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleSoftmax *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleSpaceToDepth *node) override;
   std::unique_ptr<Kernel> visit(const luci::CircleSplit *node) override;
@@ -82,8 +85,8 @@ private:
   RuntimeGraph *getRuntimeGraph(const loco::Graph *graph) const;
 
 private:
-  const ModuleLoader &_module_loader;
-  const GraphLoader &_graph_loader;
+  const std::unordered_map<const loco::Graph *, RuntimeGraph *> &_graph_to_runtime_graph;
+  const std::unordered_map<const loco::Node *, Tensor *> &_node_to_tensor;
 };
 
 } // namespace luci_interpreter