[nnc backend] Add Concat implementation (#419)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Tue, 10 Jul 2018 07:27:50 +0000 (10:27 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 10 Jul 2018 07:27:50 +0000 (16:27 +0900)
[nnc backend] Add Concat implementation

Reference concat implementation used by model IR interpreter

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Concat.h [new file with mode: 0644]
contrib/nnc/libs/backend/interpreter/core/src/ops/Concat.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Concat.h b/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Concat.h
new file mode 100644 (file)
index 0000000..6a276f0
--- /dev/null
@@ -0,0 +1,60 @@
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
+#define _NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
+
+#include "interpreter/ops/Fill.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace backend
+{
+namespace interpreter
+{
+namespace impl
+{
+
+template <typename T> class Concat : public Fill<T>
+{
+public:
+  explicit Concat(const std::vector<TensorVariant> &inputs, const Shape &outputShape,
+                  unsigned int axis)
+      : Fill<T>(outputShape, getSingleFunction(inputs, axis))
+  {
+  }
+
+private:
+  const std::function<T(const Index &)> getSingleFunction(const std::vector<TensorVariant> &inputs,
+                                                          unsigned int axis)
+  {
+    std::vector<Tensor<T>> inputAccessors;
+    for (auto &in : inputs)
+    {
+      inputAccessors.emplace_back(in);
+    }
+
+    return std::function<T(const Index &)>([inputAccessors, axis](const Index &id) -> T {
+      unsigned int mi = 0;
+      uint32_t along_axis = id.at(axis);
+
+      while (along_axis >= inputAccessors.at(mi).getShape().dim(axis))
+      {
+        along_axis -= inputAccessors[mi].getShape().dim(axis);
+        mi++;
+      }
+
+      Index local_id = id;
+      local_id.at(axis) = along_axis;
+
+      return inputAccessors[mi].at(local_id);
+    });
+  }
+};
+
+} // namespace impl
+} // namespace interpreter
+} // namespace backend
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
diff --git a/contrib/nnc/libs/backend/interpreter/core/src/ops/Concat.cpp b/contrib/nnc/libs/backend/interpreter/core/src/ops/Concat.cpp
new file mode 100644 (file)
index 0000000..13991a0
--- /dev/null
@@ -0,0 +1 @@
+#include "interpreter/ops/Concat.h"