Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleReader.cpp
index 6c9bf3a..14917ba 100644 (file)
@@ -29,12 +29,26 @@ bool is_valid(const circle::OperatorCodeT &opcode)
   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
 }
 
+bool is_valid(const circle::OperatorCode *opcode)
+{
+  assert(opcode != nullptr);
+  circle::BuiltinOperator code = opcode->builtin_code();
+  return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
+}
+
 bool is_custom(const circle::OperatorCodeT &opcode)
 {
   circle::BuiltinOperator code = opcode.builtin_code;
   return (code == circle::BuiltinOperator_CUSTOM);
 }
 
+bool is_custom(const circle::OperatorCode *opcode)
+{
+  assert(opcode != nullptr);
+  circle::BuiltinOperator code = opcode->builtin_code();
+  return (code == circle::BuiltinOperator_CUSTOM);
+}
+
 std::string opcode_name(const circle::OperatorCodeT &opcode)
 {
   if (!is_valid(opcode))
@@ -56,6 +70,30 @@ std::string opcode_name(const circle::OperatorCodeT &opcode)
   return circle::EnumNameBuiltinOperator(code);
 }
 
+std::string opcode_name(const circle::OperatorCode *opcode)
+{
+  assert(opcode != nullptr);
+
+  if (!is_valid(opcode))
+  {
+    std::ostringstream oss;
+    oss << "(invalid)";
+    return oss.str();
+  }
+
+  if (is_custom(opcode))
+  {
+    auto custom_code = opcode->custom_code()->str();
+    if (custom_code.empty())
+      return "(invalid custom)";
+
+    return custom_code;
+  }
+
+  circle::BuiltinOperator code = opcode->builtin_code();
+  return circle::EnumNameBuiltinOperator(code);
+}
+
 const char *tensor_name(const circle::TensorT &tensor)
 {
   static const char *kEmptyTensorName = "(noname)";
@@ -66,11 +104,30 @@ const char *tensor_name(const circle::TensorT &tensor)
   return kEmptyTensorName;
 }
 
+const char *tensor_name(const circle::Tensor *tensor)
+{
+  assert(tensor != nullptr);
+
+  static const char *kEmptyTensorName = "(noname)";
+  const auto tensor_name = tensor->name()->c_str();
+
+  if (!std::string(tensor_name).empty())
+    return tensor_name;
+
+  return kEmptyTensorName;
+}
+
 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
 {
   return tensor.quantization.get();
 }
 
+const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
+{
+  assert(tensor != nullptr);
+  return tensor->quantization();
+}
+
 loco::DataType luci_datatype(const circle::TensorType type)
 {
   switch (type)
@@ -235,6 +292,16 @@ luci_quantparam(const circle::QuantizationParametersT *quantization)
   return nullptr;
 }
 
+std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *qparams)
+{
+  // create temporary unpacked API object
+  assert(qparams != nullptr);
+  circle::QuantizationParametersT quantization;
+  qparams->UnPackTo(&quantization);
+
+  return luci_quantparam(&quantization);
+}
+
 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
 {
   assert(sparsity);
@@ -257,6 +324,16 @@ std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParamete
   return sparsityparam;
 }
 
+std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParameters *sparparam)
+{
+  // create temporary unpacked API object
+  assert(sparparam != nullptr);
+  circle::SparsityParametersT sparsity;
+  sparparam->UnPackTo(&sparsity);
+
+  return luci_sparsityparam(&sparsity);
+}
+
 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
 {
   node->name(tensor_name(tensor));
@@ -292,6 +369,45 @@ void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
   }
 }
 
+void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
+{
+  assert(tensor != nullptr);
+
+  node->name(tensor_name(tensor));
+  node->dtype(luci_datatype(tensor->type()));
+
+  const auto tensor_shape_signature = wrap(tensor->shape_signature());
+  const auto tensor_shape = wrap(tensor->shape());
+  assert(tensor_shape_signature.size() == 0 ||
+         tensor_shape_signature.size() == tensor_shape.size());
+
+  const auto dims = tensor_shape; // in NHWC
+  node->rank(dims.size());
+  for (uint32_t r = 0; r < dims.size(); ++r)
+  {
+    if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
+      node->dim(r).unset();
+    else
+      node->dim(r).set(dims[r]);
+  }
+
+  const auto quantization = tensor->quantization();
+  if (quantization != nullptr)
+  {
+    auto quantparam = luci_quantparam(quantization);
+    if (quantparam)
+      node->quantparam(std::move(quantparam));
+  }
+
+  const auto sparsity = tensor->sparsity();
+  if (sparsity != nullptr)
+  {
+    auto sparsityparam = luci_sparsityparam(sparsity);
+    if (sparsityparam)
+      node->sparsityparam(std::move(sparsityparam));
+  }
+}
+
 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
 {
   const auto &op_codes = opcodes();
@@ -326,7 +442,7 @@ bool CircleReader::parse(const circle::Model *model)
   _model.reset(model->UnPack());
 
   // for direct pointer access
-  _model_ptr = model;
+  _native_model = model;
 
   return true;
 }
@@ -342,12 +458,72 @@ bool CircleReader::select_subgraph(uint32_t sgindex)
   _current_subgraph = _model->subgraphs[sgindex].get();
 
   // for direct pointer access
-  auto subgraphs = _model_ptr->subgraphs();
-  const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
+  auto subgraphs = _native_model->subgraphs();
+  assert(subgraphs != nullptr);
+
+  _native_subgraph = subgraphs->Get(sgindex);
+  assert(_native_subgraph != nullptr);
 
-  _tensors_ptr = subgraph->tensors();
+  _tensors_ptr = _native_subgraph->tensors();
 
   return true;
 }
 
+template <typename T>
+VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
+{
+  // Do nothing
+}
+
+template <typename T> uint32_t VectorWrapper<T>::size() const
+{
+  return null() ? 0 : _vector->size();
+}
+
+template <typename T> const T *VectorWrapper<T>::data() const
+{
+  return null() ? nullptr : _vector->data();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
+{
+  return null() ? iterator(nullptr, 0) : _vector->begin();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
+{
+  return null() ? begin() : _vector->end();
+}
+
+template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
+{
+  if (i >= size())
+  {
+    // TODO find better error message
+    throw std::range_error("Access to prohibited vector element");
+  }
+
+  return _vector->Get(i);
+}
+
+template <typename T>
+typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
+{
+  return at(i);
+}
+
+template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
+template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
+
+#define REGISTER_WRAPPER(T) template class VectorWrapper<T>
+REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
+REGISTER_WRAPPER(int32_t);
+REGISTER_WRAPPER(uint8_t);
+#undef REGISTER_WRAPPER
+
 } // namespace luci