[nnc] Conv transpose in SoftBackend (#2290)
authorАндрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <a.shedko@partner.samsung.com>
Wed, 21 Nov 2018 17:42:29 +0000 (20:42 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 21 Nov 2018 17:42:29 +0000 (20:42 +0300)
Implements Deconvolution in Soft Backend, caffe style transfer model is fully supported.

Signed-off-by: Andrei Shedko a.shedko@partner.samsung.com
contrib/nnc/core/modelIR/ShapeInference.cpp
contrib/nnc/passes/interpreter/ops/DeConv2D.cpp
contrib/nnc/passes/soft_backend/CPPGenerator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_common_funcs.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_conv_transpose.def [new file with mode: 0644]
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/unittests/soft_backend/CPPOperations.cpp

index fd68663..3ef6069 100644 (file)
@@ -281,7 +281,6 @@ void ShapeInference::visit(ops::DeConv2DOp& op) {
   assert(in_shape.rank() == 3);
   assert(kernel_shape.dim(3) == in_shape.dim(2));
 
-
   auto pad_type = op.getPaddingType();
   auto in_rank = in_shape.rank();
   auto strides = op.getStrides();
index f57d479..c98429d 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "DeConv2D.h"
 #include "common.h"
+#include <iostream>
 
 namespace nnc {
 
@@ -39,6 +40,7 @@ std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
   const Shape& in_shape = _input.getShape();
   ShapeRange in_range(_input.getShape());
 
+
   std::shared_ptr<TensorVariant> tr_kernel;
   const std::shared_ptr<const mir::TensorVariant> kernel_ptr(
           &_kernel, []( const TensorVariant* ){});
index 82e7519..7b2ea80 100644 (file)
@@ -32,6 +32,7 @@ using namespace std;
 #include "cpp_capped_relu.generated.h"
 #include "cpp_concat.generated.h"
 #include "cpp_conv.generated.h"
+#include "cpp_conv_transpose.generated.h"
 #include "cpp_depthwise_conv.generated.h"
 #include "cpp_fully_connected.generated.h"
 #include "cpp_pool.generated.h"
@@ -283,7 +284,7 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co
   out.write(cpp_elementwise, sizeof(cpp_elementwise));
   out.write(cpp_elu, sizeof(cpp_elu));
   out.write(cpp_tanh, sizeof(cpp_tanh));
-
+  out.write(cpp_conv_transpose, sizeof(cpp_conv_transpose));
   out.write(cpp_operations, sizeof(cpp_operations));
   out.write(cpp_scale, sizeof(cpp_scale));
   out.write(cpp_dropout, sizeof(cpp_dropout));
index 98bec79..34a093c 100644 (file)
@@ -238,7 +238,7 @@ void ModelAnalyzer::visit(ops::BatchNormOp& op) {
 }
 
 void ModelAnalyzer::visit(mir::ops::TanhOp& op) {
-  addOpDescr(&op, "tanh");
+  addOpDescr(&op, "tanhActivation");
 }
 
 void ModelAnalyzer::visit(mir::ops::ElementwiseOp& op) {
@@ -264,7 +264,7 @@ void ModelAnalyzer::visit(mir::ops::EluOp& op) {
 }
 
 void ModelAnalyzer::visit(mir::ops::DeConv2DOp& op) {
-  addOpDescr(&op, "transposedconv2d");
+  addOpDescr(&op, "convTransposed2d");
 }
 
 void ModelAnalyzer::visit(ops::SqueezeOp& op) {
index bb3f601..7c09d23 100644 (file)
@@ -300,7 +300,7 @@ void Serializer::visit(mir::ops::DeConv2DOp& op) {
   // serialize kernel
   shared_ptr<TensorVariant> HWCNKernel = make_shared<TensorVariant>(op.getKernel());
   // HWCN -> "IN"HW"OUT"
-  shared_ptr<TensorVariant> NHWCKernel = transposeTensor<3, 0, 1, 2>(HWCNKernel);
+  shared_ptr<TensorVariant> NHWCKernel = transposeTensor<2, 0, 1, 3>(HWCNKernel);
   serializeTensor(*NHWCKernel);
   // serialize strides
   serializeShape(op.getStrides());
index b4a17c9..6be873a 100644 (file)
@@ -94,6 +94,165 @@ struct Dims {
   int strides[N];
 };
 
+class RuntimeShape {
+public:
+  // Shapes with dimensions up to 4 are stored directly in the structure, while
+  // larger shapes are separately allocated.
+  static constexpr int kMaxSmallSize = 4;
+
+  RuntimeShape& operator=(RuntimeShape const&) = delete;
+
+  RuntimeShape() : size_(0) {}
+
+  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+    if (dimensions_count > kMaxSmallSize) {
+      dims_pointer_ = new int32[dimensions_count];
+    }
+  }
+
+  RuntimeShape(int shape_size, int32 value) : size_(0) {
+    Resize(shape_size);
+    for (int i = 0; i < shape_size; ++i) {
+      SetDim(i, value);
+    }
+  }
+
+  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
+    ReplaceWith(dimensions_count, dims_data);
+  }
+
+  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+    BuildFrom(init_list);
+  }
+
+  // Avoid using this constructor.  We should be able to delete it when C++17
+  // rolls out.
+  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+    if (size_ > kMaxSmallSize) {
+      dims_pointer_ = new int32[size_];
+    }
+    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
+  }
+
+  bool operator==(const RuntimeShape& comp) const {
+    return this->size_ == comp.size_ &&
+           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
+  }
+
+  ~RuntimeShape() {
+    if (size_ > kMaxSmallSize) {
+
+      delete[] dims_pointer_;
+    }
+  }
+
+  inline int32 DimensionsCount() const { return size_; }
+  inline int32 Dims(int i) const {
+    TFLITE_DCHECK_GE(i, 0);
+    TFLITE_DCHECK_LT(i, size_);
+    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
+  }
+  inline void SetDim(int i, int32 val) {
+    TFLITE_DCHECK_GE(i, 0);
+    TFLITE_DCHECK_LT(i, size_);
+    if (size_ > kMaxSmallSize) {
+      dims_pointer_[i] = val;
+    } else {
+      dims_[i] = val;
+    }
+  }
+
+  inline int32* DimsData() {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+  inline const int32* DimsData() const {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+  // The caller must ensure that the shape is no bigger than 4-D.
+  inline const int32* DimsDataUpTo4D() const { return dims_; }
+
+  inline void Resize(int dimensions_count) {
+    if (size_ > kMaxSmallSize) {
+      delete[] dims_pointer_;
+    }
+    size_ = dimensions_count;
+    if (dimensions_count > kMaxSmallSize) {
+      dims_pointer_ = new int32[dimensions_count];
+    }
+  }
+
+  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
+    Resize(dimensions_count);
+    int32* dst_dims = DimsData();
+    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
+  }
+
+  template <typename T>
+  inline void BuildFrom(const T& src_iterable) {
+    const int dimensions_count =
+      std::distance(src_iterable.begin(), src_iterable.end());
+    Resize(dimensions_count);
+    int32* data = DimsData();
+    for (auto it : src_iterable) {
+      *data = it;
+      ++data;
+    }
+  }
+
+  // This will probably be factored out. Old code made substantial use of 4-D
+  // shapes, and so this function is used to extend smaller shapes. Note that
+  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
+  // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
+  // inputs should already be 4-D, so this function should not be needed.
+  inline static RuntimeShape ExtendedShape(int new_shape_size,
+                                           const RuntimeShape& shape) {
+    return RuntimeShape(new_shape_size, shape, 1);
+  }
+
+  inline void BuildFrom(const std::initializer_list<int> init_list) {
+    BuildFrom<const std::initializer_list<int>>(init_list);
+  }
+
+  // Returns the total count of elements, that is the size when flattened into a
+  // vector.
+  inline int FlatSize() const {
+    int buffer_size = 1;
+    const int* dims_data = DimsData();
+    for (int i = 0; i < size_; i++) {
+      const int dim = dims_data[i];
+      TFLITE_DCHECK_GE(dim, 1);
+      buffer_size *= dim;
+    }
+    return buffer_size;
+  }
+
+  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
+private:
+  // For use only by ExtendedShape(), written to guarantee (return-value) copy
+  // elision in C++17.
+  // This creates a shape padded to the desired size with the specified value.
+  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+    : size_(0) {
+    // If the following check fails, it is likely because a 4D-only kernel is
+    // being used with an array of larger dimension count.
+    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
+    Resize(new_shape_size);
+    const int size_increase = new_shape_size - shape.DimensionsCount();
+    for (int i = 0; i < size_increase; ++i) {
+      SetDim(i, pad_value);
+    }
+    std::memcpy(DimsData() + size_increase, shape.DimsData(),
+                sizeof(int32) * shape.DimensionsCount());
+  }
+
+  int32 size_;
+  union {
+    int32 dims_[kMaxSmallSize];
+    int32* dims_pointer_;
+  };
+};
+
 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
   TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
   TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -148,6 +307,70 @@ inline int FlatSize(const Dims<N>& dims) {
   return flat_size;
 }
 
+template <int N>
+inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
+  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
+  int flat_size = 1;
+  for (int i = 0; i < N; ++i) {
+    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
+  }
+  return flat_size;
+}
+
+// A combination of MatchingFlatSize() and FlatSizeSkipDim().
+template <int N>
+inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
+                                   const Dims<N>& check_dims_0) {
+  for (int i = 0; i < N; ++i) {
+    if (i != skip_dim) {
+      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
+    }
+  }
+  return FlatSizeSkipDim(dims, skip_dim);
+}
+
+template <int N>
+inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
+                                   const Dims<N>& check_dims_0,
+                                   const Dims<N>& check_dims_1,
+                                   const Dims<N>& check_dims_2) {
+  for (int i = 0; i < N; ++i) {
+    if (i != skip_dim) {
+      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
+    }
+  }
+  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
+}
+
+template <int N>
+inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
+                                   const Dims<N>& check_dims_0,
+                                   const Dims<N>& check_dims_1,
+                                   const Dims<N>& check_dims_2,
+                                   const Dims<N>& check_dims_3) {
+  for (int i = 0; i < N; ++i) {
+    if (i != skip_dim) {
+      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
+    }
+  }
+  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
+                                 check_dims_3);
+}
+
+// Data is required to be contiguous, and so many operators can use either the
+// full array flat size or the flat size with one dimension skipped (commonly
+// the depth).
+inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
+  const int dims_count = shape.DimensionsCount();
+  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
+  const auto* dims_data = shape.DimsData();
+  int flat_size = 1;
+  for (int i = 0; i < dims_count; ++i) {
+    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
+  }
+  return flat_size;
+}
+
 // *****************************************************************************
 // From optimized_ops.h
 
@@ -193,6 +416,23 @@ MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
   return MatrixMap<Scalar>(data, rows, cols);
 }
 
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
+                                               const RuntimeShape& shape) {
+  const int dims_count = shape.DimensionsCount();
+  const int rows = shape.Dims(dims_count - 1);
+  const int cols = FlatSizeSkipDim(shape, dims_count - 1);
+  return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
+                                                const RuntimeShape& shape) {
+  const int cols = shape.Dims(0);
+  const int rows = FlatSizeSkipDim(shape, 0);
+  return MatrixMap<Scalar>(data, rows, cols);
+}
+
 template <typename Scalar, int N>
 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                    const Dims<N>& dims,
@@ -221,3 +461,67 @@ void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
     result->noalias() = lhs * rhs;
   }
 }
+
+// Get common shape dim, DCHECKing that they all agree.
+inline int MatchingDim(const RuntimeShape& shape1, int index1,
+                       const RuntimeShape& shape2, int index2) {
+  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+  return shape1.Dims(index1);
+}
+
+template <typename... Args>
+int MatchingDim(const RuntimeShape& shape1, int index1,
+                const RuntimeShape& shape2, int index2, Args... args) {
+  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+  return MatchingDim(shape1, index1, args...);
+}
+
+enum class PaddingType : uint8 { kNone, kSame, kValid };
+
+struct PaddingValues {
+  int16 width;
+  int16 height;
+};
+
+struct ConvParams {
+  PaddingType padding_type;
+  PaddingValues padding_values;
+  // TODO(starka): This was just "stride", so check that width+height is OK.
+  int16 stride_width;
+  int16 stride_height;
+  /* not used currently
+  int16 dilation_width_factor;
+  int16 dilation_height_factor;
+  // uint8 inference params.
+  // TODO(b/65838351): Use smaller types if appropriate.
+  int32 input_offset;
+  int32 weights_offset;
+  int32 output_offset;
+  int32 output_multiplier;
+  int output_shift;
+  // uint8, etc, activation params.
+  int32 quantized_activation_min;
+  int32 quantized_activation_max;
+  // float activation params.
+  float float_activation_min;
+  float float_activation_max;
+   */
+};
+
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+  const int* dims_data = shape.DimsDataUpTo4D();
+  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
+  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
+
+inline int Offset(const Dims<4>& dims, int* index) {
+  return Offset(dims, index[0], index[1], index[2], index[3]);
+}
+
+inline int Offset(const RuntimeShape& shape, int* index) {
+  return Offset(shape, index[0], index[1], index[2], index[3]);
+}
diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_conv_transpose.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_conv_transpose.def
new file mode 100644 (file)
index 0000000..9b2aebb
--- /dev/null
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+
+template <typename T>
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+                     const RuntimeShape& input_shape, const T* input_data,
+                     const RuntimeShape& filter_shape,
+                     const RuntimeShape& output_shape, T* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK(im2col_data);
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth
+
+  // Construct the MxN sized im2col matrix.
+  // The rows M, are sub-ordered B x H x W
+  const RuntimeShape row_shape({1, batches, output_height, output_width});
+  // The columns, N, are sub-ordered Kh x Kw x Din
+  const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
+  // Use dimensions M and N to construct dims for indexing directly into im2col
+  const RuntimeShape im2col_shape(
+    {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
+
+  // Build the im2col matrix by looping through all the input pixels,
+  // computing their influence on the output, rather than looping through all
+  // the output pixels. We therefore must initialize the im2col array to zero.
+  // This is potentially inefficient because we subsequently overwrite bytes
+  // set here. However, in practice memset is very fast and costs negligible.
+  memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
+
+  // Loop through the output batches
+  for (int batch = 0; batch < batches; ++batch) {
+    // Loop through input pixels one at a time.
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        // Loop through the output pixels it will influence
+        const int out_x_origin = (in_x * stride_width) - pad_width;
+        const int out_y_origin = (in_y * stride_height) - pad_height;
+        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+          const int out_y = out_y_origin + filter_y;
+          // Is output pixel within height bounds?
+          if ((out_y >= 0) && (out_y < output_height)) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int out_x = out_x_origin + filter_x;
+              // Is output pixel within width bounds?
+              if ((out_x >= 0) && (out_x < output_width)) {
+                // Copy the input elements of this pixel
+                T const* src =
+                  input_data + Offset(input_shape, batch, in_y, in_x, 0);
+                int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+                int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
+                T* dst = im2col_data +
+                         Offset(im2col_shape, 0, 0, row_offset, col_offset);
+                memcpy(dst, src, input_depth * sizeof(T));
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+inline void TransposeConv(
+  const ConvParams& params, const RuntimeShape& input_shape,
+  const float* input_data, const RuntimeShape& filter_shape,
+  const float* filter_data, const RuntimeShape& output_shape,
+  float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+
+  // Note we could use transposed weights with forward conv for unstrided
+  // cases. But we are already getting good performance with this code as-is.
+  TFLITE_DCHECK(im2col_data);
+  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+                  output_shape, im2col_data);
+
+  const auto im2col_matrix_map =
+    MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
+  const auto filter_matrix_map =
+    MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+  auto output_matrix_map =
+    MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
\ No newline at end of file
index 3a059d6..4c37889 100644 (file)
@@ -22,6 +22,7 @@
 #include <fcntl.h>
 #include <unistd.h>
 #include <cstring>
+#include <iostream>
 
 using namespace std;
 
@@ -94,6 +95,22 @@ size_t volume(Dims<rank> d)
   return v;
 }
 
+inline RuntimeShape shapeToRuntimeShape(const Shape &s) {
+  const int rank = s.getDims();
+  RuntimeShape sh(rank);
+  for (int i = 0; i < rank; i++) {
+    sh.SetDim(i,s[i]);
+  }
+  return sh;
+}
+
+inline RuntimeShape shapeToRuntimeShapePad4(const Shape &s) {
+  assert(s.getDims()==3);
+  RuntimeShape sh({1,(int)s[0],(int)s[1],(int)s[2]});
+  return sh;
+}
+
+
 Dims<4> shapeToDims(const Shape &s)
 {
   Dims<4> dims;
@@ -142,6 +159,12 @@ struct Kernel
   Dims<4> dims;
 };
 
+struct KernelRT
+{
+  RuntimeShape shape;
+  const float *data;
+};
+
 __attribute__((unused))
 static bool isAddrAligned(const void *data, int alignment)
 {
@@ -164,6 +187,24 @@ static inline Kernel deserializeKernel(const char *&buf)
   return k;
 }
 
+static inline KernelRT deserializeKernelRT(const char *&buf)
+{
+  int32_t dType = deserializeT<int32_t>(buf);
+  assert(dType == 1 && "Unknown data type");
+  UNUSED(dType);
+  int32_t eSize = deserializeT<int32_t>(buf);
+  assert(eSize == 4 && "Unsupported element size");
+  UNUSED(eSize);
+  KernelRT k={
+    shapeToRuntimeShape(deserializeShape(buf)),
+    reinterpret_cast<const float *>(buf)
+  };
+
+  assert(isAddrAligned(buf, 4) && "data should be aligned to 4 bytes to use arm vector instructions");
+  buf += k.shape.FlatSize() * eSize;
+  return k;
+}
+
 // This operation takes as input multiple tensors, at least 2, likely less then 7
 // parameter pack provides generalization for all possible number of inputs
 template <class ...Args>
@@ -232,6 +273,46 @@ void conv2d(Tensor &out, const char *params, const Tensor &in)
        im2col.get(), im2col_d);
 }
 
+void convTransposed2d(Tensor &out, const char *params, const Tensor &in) {
+  const float *input = in.getData();
+  RuntimeShape input_shape = shapeToRuntimeShapePad4(in.getShape());
+  KernelRT kernel = deserializeKernelRT(params);
+  Shape strides = deserializeShape(params);
+  // pads type. unused for now
+  int32_t pt = deserializeT<int32_t>(params);
+  UNUSED(pt);
+  Shape pads = deserializeShape(params);
+  Shape out_s = deserializeShape(params);
+
+  out.reShape(out_s);
+
+  RuntimeShape out_shape = shapeToRuntimeShapePad4(out_s);
+
+  const short stride_w = strides[1];
+  const short stride_h = strides[0];
+  assert(strides[2] == 1);
+  const short pad_w = pads[1];
+  const short pad_h = pads[0];
+
+  const int kw = kernel.shape.Dims(2);
+  const int kh = kernel.shape.Dims(1);
+
+  RuntimeShape im2col_shape = RuntimeShape({1,1,(int) (out_s[0]*out_s[1]),
+                                       input_shape.Dims(3)*kw*kh});
+
+  const auto convPara = ConvParams({PaddingType::kSame,
+                                    PaddingValues({pad_w,pad_h}), stride_w, stride_h});
+
+  unique_ptr<float, void(*)(float *)> im2col(nullptr, [](float *d){delete [] d;});
+  if (stride_w != 1 || stride_h != 1 || kernel.shape.Dims(1) != 1 || kernel.shape.Dims(2) != 1) {
+    im2col.reset(new float[im2col_shape.FlatSize()]);
+  }
+
+  TransposeConv(
+    convPara, input_shape, input, kernel.shape, kernel.data,
+    out_shape, out.getData(), im2col_shape, im2col.get());
+}
+
 void depthwiseConv2d(Tensor &out, const char *params, const Tensor &in)
 {
   const float *input = in.getData();
index f0b20f7..6b82e0a 100644 (file)
@@ -30,6 +30,7 @@
 #include "code_snippets/cpp_capped_relu.def"
 #include "code_snippets/cpp_concat.def"
 #include "code_snippets/cpp_conv.def"
+#include "code_snippets/cpp_conv_transpose.def"
 #include "code_snippets/cpp_depthwise_conv.def"
 #include "code_snippets/cpp_fully_connected.def"
 #include "code_snippets/cpp_pool.def"
@@ -444,6 +445,37 @@ TEST(cpp_operations_test, max4) {
   }
 }
 
+TEST(cpp_operations_test, convTransposed2d)
+{
+  // Iterate over kernel width, kernel height,
+  // input channels(inputC), output channels(outputC),
+  // stride width, stride height
+  // size 3 is chosen to cover all cases, where width bigger/smaller then height and equal/not equal to 1
+  using iT = int32_t;
+  for (iT kernelH = 2; kernelH <= 4; ++kernelH)
+    for (iT kernelW = 2; kernelW <= 4; ++kernelW)
+      for (iT inputC = 1; inputC <= 3; ++inputC)
+        for (iT outputC = 1; outputC <= 3; ++outputC)
+          for (iT strideH = 1; strideH <= 3; ++strideH)
+            for (iT strideW = 1; strideW <= 3; ++strideW) {
+              vector<int> inputShapeData{9, 3, static_cast<int>(inputC)};  // HWC
+              mir::Shape kernelShape{kernelH, kernelW, outputC, inputC};
+              mir::Shape strides{strideH, strideW, 1};
+              vector<unique_ptr<mir::TensorVariant>> inputNTensors(1);
+              Tensor aInputTensor;
+              fillTensors(inputNTensors[0], aInputTensor, inputShapeData, 1.0f);
+              auto padT = mir::ops::PaddingType::Same;
+              mir::TensorVariant kernel = createNTensor(kernelShape, 1.0f);
+              auto opGenerator = [kernel, strides, padT](mir::Graph &g)
+              {
+                return g.create<mir::ops::DeConv2DOp>("y", kernel, strides, padT);
+              };
+
+              createAndRunTestGraph(opGenerator, convTransposed2d, inputNTensors, aInputTensor);
+            }
+}
+
+
 TEST(cpp_operations_test, conv2d)
 {
   // Iterate over kernel width, kernel height,