[nnc] Fix interpreter deconv implementation (#2756)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@samsung.com>
Wed, 26 Dec 2018 14:18:41 +0000 (17:18 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 26 Dec 2018 14:18:41 +0000 (17:18 +0300)
Rotate kernel index prior to input index calculation

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/passes/interpreter/ops/DeConv2D.cpp

index 750f967..e19e9e1 100644 (file)
@@ -55,12 +55,17 @@ std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
     for (auto &kernel_idx_r : kernel_range) {
       auto kernel_idx = kernel_idx_r;
 
+      // rotate kernel 180 deg around last axis
+      // by index transform
+      for (int32_t d = 0; d < 2; ++d) {
+        kernel_idx.at(d) = kernel.getShape().dim(d) - kernel_idx.at(d) - 1;
+      }
+
       // flag that keeps info on whether the current input element is from input
       // or is from dilation by stride
       bool is_from_input = true;
       for (int32_t d = 1; d < input_idx.rank() - 1; ++d) {
-        const auto num = (out_idx.at(d) - kernel.getShape().dim(d - 1) + pads.at(d - 1) + 1 +
-                          kernel_idx.at(d - 1));
+        const auto num = (out_idx.at(d) + pads.at(d - 1) - kernel_idx.at(d - 1));
         const auto div_res = num / _strides.dim(d - 1);
         const auto rem = num % _strides.dim(d - 1);
         is_from_input = is_from_input && rem == 0;
@@ -72,12 +77,6 @@ std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
       // channel index - same as kernel's
       input_idx.at(3) = kernel_idx.at(2);
 
-      // rotate kernel 180 deg around last axis
-      // by index transform
-      for (int32_t d = 0; d < 2; ++d) {
-        kernel_idx.at(d) = kernel.getShape().dim(d) - kernel_idx.at(d) -1;
-      }
-
       if (in_range.contains(input_idx) and is_from_input) {
         auto kernel_region = kernel.getRegion(kernel_idx);
         assert( kernel_region.size() == num_kernels );