[TfLite] Implement TfOpIdxMap
authorJihoon Lee <jhoon.it.lee@samsung.com>
Tue, 13 Apr 2021 13:16:11 +0000 (22:16 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 17 May 2021 04:20:56 +0000 (13:20 +0900)
This patch implments TfOpIdxMap
Please note that graph::initialize has been extracted for modularity and
testability within this patch.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/compiler/tflite_interpreter.cpp
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/models/neuralnet.cpp
test/unittest/compiler/unittest_interpreter.cpp

index 1b81725faa3de7e25f45215ff0127d2c01736644..0bd3547cbcbe4bf4815e47cb121240e7713efe6d 100644 (file)
@@ -15,6 +15,7 @@
 #include <fstream>
 #include <memory>
 #include <string>
+#include <tuple>
 
 #include <tf_schema_generated.h>
 
@@ -199,30 +200,114 @@ private:
 using TfOpNodes = std::vector<TfOpNode>;
 
 /**
- * @brief tensorflow operation index map, this class manages operation index
- * mapping
+ * @brief Bidirectional Index map
  *
+ * @tparam T type of a underlying value, please note that T will be copied, so
+ * please use this for pointers and primitive values that is okay to copy
  */
-class TfOpIdxMap {
+template <typename T> class BidirectionalIndexMap {
 public:
-  TfOpIdxMap(std::vector<TfOpNode> nodes){};
+  /**
+   * @brief addDatapoint to the map
+   *
+   * @param data data to be added if there is no occurrence, data will be
+   * copied.
+   */
+  void addDataWhenNotFound(T data) {
+    auto search = data2index.find(data);
+
+    if (search == data2index.end()) {
+      data2index[data] = index2data.size();
+      index2data.push_back(data);
+    }
+  }
 
-private:
   /**
-   * @brief Bidirectional Index map
+   * @brief Get the Index of the data
    *
-   * @tparam T type of a underyling value
+   * @param key data that will be the key
+   * @return unsigned int index
    */
-  template <typename T> class BidirectionalIndexMap {
-    std::unordered_map<T, unsigned int> data2index; /**< data -> index map */
-    std::vector<T> index2data;                      /**< index -> data map */
-  };
+  unsigned int getIndex(const T &key) const {
+    auto search = data2index.find(key);
+
+    NNTR_THROW_IF(search == data2index.end(), std::invalid_argument)
+      << FUNC_TAG << "Cannot find index for key: " << key;
 
+    return *search;
+  }
+
+  /**
+   * @brief Get the Data object
+   *
+   * @param idx index to be searched
+   * @return T datapoint T
+   */
+  T getData(unsigned int index) const {
+    NNTR_THROW_IF(index >= index2data.size(), std::invalid_argument)
+      << FUNC_TAG << "Cannot find data for index: " << index;
+
+    return index2data[index];
+  }
+
+private:
+  std::unordered_map<T, unsigned int> data2index; /**< data -> index map */
+  std::vector<T> index2data;                      /**< index -> data map */
+};
+
+/**
+ * @brief tensorflow operation index map, this class manages operation index
+ * mapping
+ *
+ */
+class TfOpIdxMap {
+public:
+  TfOpIdxMap(const TfOpNodes &nodes) {
+    auto &opcode_map = getIndexMap<tflite::BuiltinOperator>();
+    auto update_opcode = [&opcode_map](tflite::BuiltinOperator opcode) {
+      opcode_map.addDataWhenNotFound(opcode);
+    };
+
+    auto &buffer_map = getIndexMap<const float *>();
+    buffer_map.addDataWhenNotFound(empty_buffer); /// put empty buffer to first
+
+    auto update_buffers = [&buffer_map](const TfOpNode::Variables &variables) {
+      for (auto &variable : variables) {
+        const Tensor &t = variable->getVariableRef();
+        if (!t.uninitialized() && t.isAllocated()) {
+          buffer_map.addDataWhenNotFound(t.getData());
+        }
+      }
+    };
+
+    auto &variable_map = getIndexMap<const Var_Grad *>();
+    auto update_variables =
+      [&variable_map](const TfOpNode::Variables &variables) {
+        for (auto &variable : variables) {
+          variable_map.addDataWhenNotFound(variable);
+        }
+      };
+
+    for (auto &op_node : nodes) {
+      update_opcode(op_node.getOpType());
+      update_variables(op_node.getInputs());
+      update_variables(op_node.getOutputs());
+      update_variables(op_node.getWeights());
+      update_buffers(op_node.getWeights());
+    }
+  }
+
+  template <typename T> BidirectionalIndexMap<T> &getIndexMap() {
+    return std::get<BidirectionalIndexMap<T>>(maps);
+  }
+
+private:
   float empty_buffer[0]; /**< unintialized tensor points to this buffer */
 
-  BidirectionalIndexMap<float *> buffer_map; /**< underlying buffer map */
-  BidirectionalIndexMap<tflite::BuiltinOperator> opcode_map; /**< opcode map */
-  BidirectionalIndexMap<Var_Grad *> variable_map;            /**< tensor map */
+  std::tuple<BidirectionalIndexMap<const float *>, /**< underlying buffer map */
+             BidirectionalIndexMap<tflite::BuiltinOperator>, /**< opcode map */
+             BidirectionalIndexMap<const Var_Grad *>>        /**< tensor map */
+    maps;
 };
 
 TfOpNodes
@@ -231,7 +316,6 @@ buildOpNodes(std::shared_ptr<const GraphRepresentation> representation) {
   /// @todo, look ahead of layers to get nodes that can be fused
   for (const auto &ln : representation->getSorted()) {
     nodes.emplace_back(*ln->getObject());
-    std::cout << ln->getObject()->getName() << '\n';
   }
 
   return nodes;
index 0e76acab7dcca40ed05d89b0eb6c9b42719e4470..0c3653f527046306beb2d201627a93cc02412c74 100644 (file)
@@ -925,4 +925,94 @@ std::vector<std::shared_ptr<LayerNode>> &NetworkGraph::getSorted() {
   return Sorted;
 }
 
+int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
+  int status = ML_ERROR_NONE;
+
+  for (unsigned int idx = 0; idx < Sorted.size(); ++idx) {
+    bool first = idx == 0;
+    auto &lnode = getSortedLayerNode(idx);
+    auto &lptr = lnode->getObject();
+    ml_logd("layer name : %s", lptr->getName().c_str());
+    std::string cur_type;
+    if (lptr->getType() == TimeDistLayer::type) {
+      cur_type =
+        std::dynamic_pointer_cast<TimeDistLayer>(lptr)->getDistLayerType();
+    } else {
+      cur_type = lptr->getType();
+    }
+
+    /**
+     * Set input dimension for all the layers.
+     * For input layer, as input dimension is known, set input tensor.
+     */
+    if (!first) {
+      std::string l_pre_type =
+        getSortedLayerNode(idx - 1)->getObject()->getType();
+      if (l_pre_type == TimeDistLayer::type) {
+        l_pre_type = std::dynamic_pointer_cast<TimeDistLayer>(
+                       getSortedLayerNode(idx - 1)->getObject())
+                       ->getDistLayerType();
+      }
+
+      if (istrequal(l_pre_type, ActivationLayer::type) &&
+          istrequal(cur_type, ActivationLayer::type)) {
+        ml_loge("double activation is not allowed");
+        return ML_ERROR_INVALID_PARAMETER;
+      }
+
+      for (unsigned int i = 0; i < lptr->input_layers.size(); ++i) {
+        Layer &in_layer = *getLayerNode(lptr->input_layers[i])->getObject();
+
+        unsigned int location = 0;
+        for (unsigned int j = 0; j < in_layer.output_layers.size(); ++j) {
+          if (in_layer.output_layers[j] == lptr->getName()) {
+            location = j;
+            break;
+          }
+        }
+
+        lptr->setInputDimension(in_layer.getOutputDimension()[location], i);
+      }
+    }
+
+    /**
+     * Initialize all the layers, allocate output tensors for each layer
+     * and add optimizer related weights for the layer
+     */
+    status = lptr->initialize(*manager);
+    NN_RETURN_STATUS();
+
+    auto &in_out = manager->trackLayerOutputs(cur_type, lptr->getName(),
+                                              lptr->getOutputDimension(),
+                                              lptr->getInputDimension());
+    lptr->setOutputBuffers(in_out);
+
+    /** Connect the output of the previous layers with the input of the current
+     * layer */
+    if (!first) {
+      for (unsigned int i = 0; i < lptr->input_layers.size(); ++i) {
+        Layer &in_layer = *getLayerNode(lptr->input_layers[i])->getObject();
+
+        unsigned int location = 0;
+        for (unsigned int j = 0; j < in_layer.output_layers.size(); ++j) {
+          if (in_layer.output_layers[j] == lptr->getName()) {
+            location = j;
+            break;
+          }
+        }
+
+        lptr->net_input[i] = getLayerNode(lptr->input_layers[i])
+                               ->getObject()
+                               ->net_hidden[location];
+      }
+    } else {
+      auto &in_out = manager->trackLayerInputs(cur_type, lptr->getName(),
+                                               lptr->getInputDimension(),
+                                               lptr->getOutputDimension());
+      lptr->setInputBuffers(in_out);
+    }
+  }
+  return status;
+}
+
 } /* namespace nntrainer */
index febc678faf6f907c04488f0ad219a8e668fdb229..1239deb6a45b3b4713355f2b470a5947685dd5fb 100644 (file)
@@ -242,6 +242,14 @@ public:
     return *this;
   }
 
+  /**
+   * @brief initialize network graph, with given manager
+   * @note this is taken from neuralnet, This might need some changes
+   *
+   * @param manager manager to allocate tensors
+   */
+  int initialize(std::shared_ptr<Manager> manager);
+
 private:
   std::map<std::string, std::string> sub_in_out; /** This is map to identify
                    input and output layer name of subgraph */
index 766ef91b801e526e32ab59b6486d6beaf64f5b7c..afdf5103a00fce172f0bbb135a5ffc8fdfac7514 100644 (file)
@@ -168,93 +168,8 @@ int NeuralNetwork::initialize() {
 
   setBatchSize();
 
-  for (unsigned int idx = 0; idx < n_layers; ++idx) {
-    bool first = idx == 0;
-    auto &lnode = model_graph.getSortedLayerNode(idx);
-    auto &lptr = lnode->getObject();
-    ml_logd("layer name : %s", lptr->getName().c_str());
-    std::string cur_type;
-    if (lptr->getType() == TimeDistLayer::type) {
-      cur_type =
-        std::dynamic_pointer_cast<TimeDistLayer>(lptr)->getDistLayerType();
-    } else {
-      cur_type = lptr->getType();
-    }
-
-    /**
-     * Set input dimension for all the layers.
-     * For input layer, as input dimension is known, set input tensor.
-     */
-    if (!first) {
-      std::string l_pre_type =
-        model_graph.getSortedLayerNode(idx - 1)->getObject()->getType();
-      if (l_pre_type == TimeDistLayer::type) {
-        l_pre_type = std::dynamic_pointer_cast<TimeDistLayer>(
-                       model_graph.getSortedLayerNode(idx - 1)->getObject())
-                       ->getDistLayerType();
-      }
-      if (istrequal(l_pre_type, ActivationLayer::type) &&
-          istrequal(cur_type, ActivationLayer::type)) {
-        ml_loge("double activation is not allowed");
-        return ML_ERROR_INVALID_PARAMETER;
-      }
-
-      for (unsigned int i = 0; i < lptr->input_layers.size(); ++i) {
-        Layer &in_layer =
-          *model_graph.getLayerNode(lptr->input_layers[i])->getObject();
-
-        unsigned int location = 0;
-        for (unsigned int j = 0; j < in_layer.output_layers.size(); ++j) {
-          if (in_layer.output_layers[j] == lptr->getName()) {
-            location = j;
-            break;
-          }
-        }
-
-        lptr->setInputDimension(in_layer.getOutputDimension()[location], i);
-      }
-    }
-
-    /**
-     * Initialize all the layers, allocate output tensors for each layer
-     * and add optimizer related weights for the layer
-     */
-    status = lptr->initialize(*manager);
-    NN_RETURN_STATUS();
-
-    REGISTER_EVENT(lptr->getName(), lnode->event_key)
-
-    auto &in_out = manager->trackLayerOutputs(cur_type, lptr->getName(),
-                                              lptr->getOutputDimension(),
-                                              lptr->getInputDimension());
-    lptr->setOutputBuffers(in_out);
-
-    /** Connect the output of the previous layers with the input of the current
-     * layer */
-    if (!first) {
-      for (unsigned int i = 0; i < lptr->input_layers.size(); ++i) {
-        Layer &in_layer =
-          *model_graph.getLayerNode(lptr->input_layers[i])->getObject();
-
-        unsigned int location = 0;
-        for (unsigned int j = 0; j < in_layer.output_layers.size(); ++j) {
-          if (in_layer.output_layers[j] == lptr->getName()) {
-            location = j;
-            break;
-          }
-        }
-
-        lptr->net_input[i] = model_graph.getLayerNode(lptr->input_layers[i])
-                               ->getObject()
-                               ->net_hidden[location];
-      }
-    } else {
-      auto &in_out = manager->trackLayerInputs(cur_type, lptr->getName(),
-                                               lptr->getInputDimension(),
-                                               lptr->getOutputDimension());
-      lptr->setInputBuffers(in_out);
-    }
-  }
+  status = model_graph.initialize(manager);
+  NN_RETURN_STATUS();
 
   // initialize optimizer and related variables
   if (opt) {
index 45abea645f6d9c96567a847cbfad0c66209d21c9..a4d47b50fc8676923beee9f248e4a55f1e93d6f1 100644 (file)
@@ -161,16 +161,28 @@ TEST_P(nntrainerInterpreterTest, graphSerializeAfterDeserialize) {
 }
 
 auto fc0 = LayerReprentation("fully_connected",
-                             {"name=fc0", "unit=1", "input_shape=1:1:100"});
+                             {"name=fc0", "unit=1", "input_shape=1:1:10"});
+auto fc1 = LayerReprentation("fully_connected",
+                             {"name=fc1", "unit=1", "input_shape=1:1:10"});
 
 auto flatten = LayerReprentation("flatten", {"name=flat"});
 
 #ifdef ENABLE_TFLITE_INTERPRETER
 TEST(flatbuffer, playground) {
+
+  auto manager = std::make_shared<nntrainer::Manager>();
+
   nntrainer::TfliteInterpreter interpreter;
-  auto g = makeGraph({fc0});
-  g->compile(nntrainer::LossType::LOSS_NONE);
+  auto g = makeGraph({fc0, fc1});
+  EXPECT_EQ(g->compile(nntrainer::LossType::LOSS_NONE), ML_ERROR_NONE);
+  EXPECT_EQ(g->initialize(manager), ML_ERROR_NONE);
+
+  manager->initializeWeights();
+  manager->allocateWeights();
+
   interpreter.serialize(g, "test.tflite");
+
+  manager->deallocateWeights();
 }
 #endif
 /**