[enco] Record Input/Output Shape metadata (#1153)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 24 Aug 2018 00:33:29 +0000 (09:33 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 24 Aug 2018 00:33:29 +0000 (09:33 +0900)
With this commit, generated C++ code includes the shape metadata of each
input and output.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/core/src/CppCode.cpp
contrib/enco/core/src/CppGen/Global.cpp
contrib/enco/core/src/CppGen/Global.h
contrib/enco/core/src/Dims.h [new file with mode: 0644]

index 73b6a6a..93063b6 100644 (file)
@@ -2,6 +2,8 @@
 
 #include "CppGen/Global.h"
 
+#include "Dims.h"
+
 #include <pp/LinearDocument.h>
 #include <pp/MultiLineTextUtils.h>
 
@@ -33,11 +35,13 @@ void CppCode::dump(std::ostream &os) const
   {
     net_def.append("struct ", name, " {");
     net_def.indent();
+    net_def.append("struct Shape { uint32_t rank; const uint32_t *dims; };");
     net_def.append("struct Input {");
     net_def.indent();
     net_def.append("const char *name;");
     net_def.append("const uint8_t *ptr;");
     net_def.append("unsigned len;");
+    net_def.append("Shape shape;");
     net_def.unindent();
     net_def.append("};");
     net_def.append("struct Output {");
@@ -45,6 +49,7 @@ void CppCode::dump(std::ostream &os) const
     net_def.append("const char *name;");
     net_def.append("uint8_t *ptr;");
     net_def.append("unsigned len;");
+    net_def.append("Shape shape;");
     net_def.unindent();
     net_def.append("};");
     net_def.append();
@@ -69,20 +74,28 @@ void CppCode::dump(std::ostream &os) const
     for (uint32_t n = 0; n < m->input()->size(); ++n)
     {
       auto input = m->input()->at(n);
+      auto dims = as_dims(input->shape());
 
       auto name_var = global.constant(input->name());
+      auto dims_var = global.constant<uint32_t>(dims);
 
       net_ctor.append("inputs.at(", n, ").name = ", name_var, ";");
+      net_ctor.append("inputs.at(", n, ").shape.rank = ", dims.size(), ";");
+      net_ctor.append("inputs.at(", n, ").shape.dims = ", dims_var, ";");
     }
 
     // Initialize output metadata
     for (uint32_t n = 0; n < m->output()->size(); ++n)
     {
       auto output = m->output()->at(n);
+      auto dims = as_dims(output->shape());
 
       auto name_var = global.constant(output->name());
+      auto dims_name = global.constant<uint32_t>(dims);
 
       net_ctor.append("outputs.at(", n, ").name = ", name_var, ";");
+      net_ctor.append("outputs.at(", n, ").shape.rank = ", dims.size(), ";");
+      net_ctor.append("outputs.at(", n, ").shape.dims = ", dims_name, ";");
     }
 
     // TODO Implement this
index 5b9e134..b991eab 100644 (file)
@@ -14,4 +14,26 @@ std::string Global::constant(const std::string &s)
   return name;
 }
 
+template <> std::string Global::constant(const std::vector<uint32_t> &values)
+{
+  auto name = pp::fmt("g_", _count++);
+
+  std::stringstream ss;
+
+  ss << "const uint32_t " << name << "[" << values.size() << "] = { ";
+  if (values.size() > 0)
+  {
+    ss << values.at(0);
+    for (uint32_t n = 1; n < values.size(); ++n)
+    {
+      ss << ", " << values.at(n);
+    }
+  }
+  ss << " };";
+
+  _content.append(ss.str());
+
+  return name;
+}
+
 } // namespace enco
index f23a913..5dbdf52 100644 (file)
@@ -18,6 +18,9 @@ public:
   // @brief Create a global constant string (const char *) literal, and return variable name
   std::string constant(const std::string &value);
 
+  // @brief Create a global constant array variable of type T
+  template <typename T> std::string constant(const std::vector<T> &values);
+
 public:
   const pp::MultiLineText &content(void) const { return _content; }
 
diff --git a/contrib/enco/core/src/Dims.h b/contrib/enco/core/src/Dims.h
new file mode 100644 (file)
index 0000000..24fd4fa
--- /dev/null
@@ -0,0 +1,18 @@
+#ifndef __DIMS_H__
+#define __DIMS_H__
+
+#include <nncc/core/ADT/tensor/Shape.h>
+
+static inline std::vector<uint32_t> as_dims(const nncc::core::ADT::tensor::Shape &shape)
+{
+  std::vector<uint32_t> res;
+
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+  {
+    res.emplace_back(shape.dim(axis));
+  }
+
+  return res;
+}
+
+#endif // __DIMS_H__