[nnc backend] Add DeptwiseConv2D operation implementation (#418)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Tue, 10 Jul 2018 07:43:01 +0000 (10:43 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 10 Jul 2018 07:43:01 +0000 (16:43 +0900)
[nnc backend] Add DeptwiseConv2D operation implementation

Add DepthwiseConv2D reference implementation used by model IR interpreter

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Depthwise_conv_2D.h [new file with mode: 0644]
contrib/nnc/libs/backend/interpreter/core/src/ops/Depthwise_conv_2D.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Depthwise_conv_2D.h b/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/Depthwise_conv_2D.h
new file mode 100644 (file)
index 0000000..f305415
--- /dev/null
@@ -0,0 +1,44 @@
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_DEPTHWISE_CONV2D_IMPL_
+#define _NNC_CORE_BACKEND_INTERPRETER_DEPTHWISE_CONV2D_IMPL_
+
+#include "interpreter/ops/OperationImpl.h"
+
+#include "nnc/core/IR/model/operations/common.h"
+#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace backend
+{
+namespace interpreter
+{
+namespace impl
+{
+
+using nncc::contrib::core::IR::model::ops::DepthwiseConv2DOp;
+using nncc::contrib::core::IR::model::ops::PaddingType;
+
+class DepthwiseConv2D : public OperationImpl<float>
+{
+public:
+  explicit DepthwiseConv2D(const TensorVariant &input, const DepthwiseConv2DOp &op);
+  virtual std::vector<TensorVariant> operator()() override;
+
+private:
+  const Tensor<float> _input;
+  const Tensor<float> _kernel;
+  const Shape _strides;
+  const PaddingType _padding;
+  const Shape &_out_shape;
+  const DepthwiseConv2DOp &_op;
+};
+
+} // namespace impl
+} // namespace interpreter
+} // namespace backend
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_DEPTHWISE_CONV2D_IMPL_
diff --git a/contrib/nnc/libs/backend/interpreter/core/src/ops/Depthwise_conv_2D.cpp b/contrib/nnc/libs/backend/interpreter/core/src/ops/Depthwise_conv_2D.cpp
new file mode 100644 (file)
index 0000000..a9816f0
--- /dev/null
@@ -0,0 +1,71 @@
+#include "nnc/core/linalg/ShapeRange.h"
+
+#include "interpreter/ops/Depthwise_conv_2D.h"
+#include "interpreter/ops/common.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace backend
+{
+namespace interpreter
+{
+namespace impl
+{
+
+std::vector<TensorVariant> DepthwiseConv2D::operator()()
+{
+  TensorVariant res = allocate_tensor(_out_shape);
+  Tensor<float> resAccessor(res);
+
+  Shape strides({_strides.dim(0), _strides.dim(1), _strides.dim(2)});
+  Index pads({(uint32_t)_op.getPadding(0), (uint32_t)_op.getPadding(1), 0u});
+
+  Shape outShape = res.getShape();
+  outShape.dim(2) = 1;
+  ShapeRange outRange(outShape);
+
+  ShapeRange inRange(_input.getShape());
+
+  Index inIdx;
+  inIdx.resize(outShape.rank());
+
+  for (auto &outIdx : outRange)
+  {
+    for (auto &kIdx : ShapeRange(_kernel.getShape()))
+    {
+      translate(inIdx, outIdx, kIdx, strides, pads);
+
+      if (inRange.contains(inIdx))
+      {
+        auto in = _input.at(inIdx);
+        auto b = _kernel.at(kIdx);
+        Index outIdxK = outIdx;
+        outIdxK.at(2) = kIdx.at(2);
+        resAccessor.at(outIdxK) += in * b;
+      }
+    }
+  }
+
+  return {res};
+}
+
+DepthwiseConv2D::DepthwiseConv2D(const TensorVariant &input, const DepthwiseConv2DOp &op)
+    : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
+      _padding(op.getPaddingType()), _out_shape(op.getOutputShape(0)), _op(op)
+{
+
+  assert(_op.getInputShape(0).rank() == 3);
+  assert(input.getShape().rank() == 3);
+  assert(_kernel.getShape().rank() == 3);
+  assert(_strides.dim(2) == 1);
+  assert(_op.getPadding(2) == 0);
+  assert(_kernel.getShape().dim(2) == _input.getShape().dim(2));
+}
+
+} // namespace impl
+} // namespace interpreter
+} // namespace backend
+} // namespace contrib
+} // namespace nncc