Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ConstantInitializer.cc
index 71e3136..deb27f0 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include "ConstantInitializer.h"
+#include "Tensor.h"
 
 namespace onert
 {
@@ -30,39 +31,61 @@ ConstantInitializer::ConstantInitializer(const ir::Operands &operands,
   // DO NOTHING
 }
 
+void ConstantInitializer::registerDefaultInitializer(const ir::OperandIndex &index,
+                                                     const ir::Operand &obj)
+{
+  registerExternalInitializer(index, obj);
+}
+
+void ConstantInitializer::registerExternalInitializer(const ir::OperandIndex &index,
+                                                      const ir::Operand &obj)
+{
+  // For only CONSTANTS
+  // TODO Add to check if tensor has been allocated
+  if (!obj.isConstant())
+    return;
+
+  _init_map[index] = [](const onert::ir::Operand &model_obj, onert::backend::ITensor &itensor) {
+    auto data = model_obj.shareData();
+    assert(data && data->base());
+    ExternalTensor &tensor = dynamic_cast<ExternalTensor &>(itensor);
+    tensor.setData(data);
+  };
+}
+
 void ConstantInitializer::visit(const ir::operation::Conv2D &node)
 {
   const auto &kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
-  registerCopyInitializer(kernel_index, kernel_obj);
+  registerExternalInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerCopyInitializer(bias_index, bias_obj);
+  registerExternalInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const ir::operation::DepthwiseConv2D &node)
 {
   const auto &kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
-  registerCopyInitializer(kernel_index, kernel_obj);
+  registerExternalInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerCopyInitializer(bias_index, bias_obj);
+  registerExternalInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const ir::operation::FullyConnected &node)
 {
   const auto &weight_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
   const auto &weight_obj = _operands.at(weight_index);
-  registerCopyInitializer(weight_index, weight_obj);
+  registerExternalInitializer(weight_index, weight_obj);
 
   const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
   if (!bias_index.undefined())
   {
     const auto &bias_obj = _operands.at(bias_index);
-    registerCopyInitializer(bias_index, bias_obj);
+    registerExternalInitializer(bias_index, bias_obj);
   }
 }