[nnc] Initial implementation of caffe2_op_creator (#2333)
author Ivan Vagin/AI Tools Lab /SRR/Engineer/Samsung Electronics <ivan.vagin@samsung.com>
Fri, 7 Dec 2018 11:23:02 +0000 (14:23 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 7 Dec 2018 11:23:02 +0000 (14:23 +0300)
Initial implementation of caffe2_op_creator.

`mobilenet` is supported; supporting the `inception` model still requires custom paddings in the pooling ops, plus tests for the currently untested operation conversions.

Implemented ops:
- Add
- AveragePool
- Conv
- Concat
- Dropout
- FC
- GivenTensorFill
- MaxPool
- Mul
- Relu
- Softmax
- SpatialBN
- Sum

Implemented but not yet tested ops:
- Add
- Concat
- Mul
- SpatialBN

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
12 files changed:
contrib/nnc/driver/Options.cpp
contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h
contrib/nnc/include/passes/common_frontend/op_creator_helper.h [new file with mode: 0644]
contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp [new file with mode: 0644]
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h [new file with mode: 0644]
contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h
contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_proto_helper.h
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/common_frontend/CMakeLists.txt
contrib/nnc/passes/common_frontend/op_creator_helper.cpp [new file with mode: 0644]

index 317dc48..028bd67 100644 (file)
@@ -92,7 +92,7 @@ Option<std::string> initNet(optname("--init-net"),
                             std::string(),
                             optional(false),
                             optvalues(""),
-                            nullptr,
+                            checkInFile,
                             separators(""),
 #ifdef NNC_FRONTEND_CAFFE2_ENABLED
                             showopt(true),
index b9fa140..db637ce 100644 (file)
@@ -33,7 +33,7 @@ class OperatorDef;
 class NetDef;
 }
 namespace nnc {
-// class Caffe2OpCreator;
+class Caffe2OpCreator;
 enum class SupportedCaffe2OpType : uint8_t;
 }
 
@@ -68,7 +68,7 @@ private:
   std::string _initNet;
   mir::Graph* _graph;
   std::unique_ptr<::caffe2::NetDef> _net;
-  // std::unique_ptr<Caffe2OpCreator> _opCreator;
+  std::unique_ptr<Caffe2OpCreator> _opCreator;
   std::vector<mir::Shape> _inputShapes;
 
   static const std::map<std::string, SupportedCaffe2OpType> _operatorTypes;
@@ -77,7 +77,7 @@ private:
   // This map maps caffe2 operators names to MIR operators
   // that correspond to previous caffe2 operators
   std::map<std::string, mir::IODescriptor> _blobNameToIODescriptor;
-  mir::Operation* _lastNode;
+  mir::Operation* _lastMIROp;
 
   std::map<std::string, std::shared_ptr<MIRTensor>> _MIRTensors;
 
@@ -85,47 +85,42 @@ private:
   * @brief Pass through caffe2 graph and collect ops unsupported by NNC
   * @throw PassException with message, containing detected problems
   */
-  // void collectUnsupportedOps();
+  void collectUnsupportedOps();
 
   /**
   * @brief Collecting unsupported parts of caffe2 operator
   */
-  // void collectUnsupportedOp(const ::caffe2::OperatorDef&);
+  void collectUnsupportedOp(const ::caffe2::OperatorDef&);
 
   /**
   * @brief Creating MIR node from single caffe2 operator
   */
-  // void createMIRNodesFromOp(const ::caffe2::OperatorDef&);
+  void createMIRNodesFromOp(const ::caffe2::OperatorDef&);
 
   /**
   * @brief Since caffe2 tensor values stored separately (in init_net) - preload them in _MIRTensors
   */
-  // void preloadAllTensors();
+  void preloadAllTensors();
 
   /**
   * @brief Creates MIR tensor from caffe2 givenTensorFill op
   */
-  // std::shared_ptr<mir::TensorVariant> createTensor(const ::caffe2::OperatorDef&);
+  std::shared_ptr<mir::TensorVariant> createTensor(const ::caffe2::OperatorDef&);
 
   /**
   * @brief Returns MIR ops, under given caffe2 op
   */
-  // std::vector<mir::IODescriptor> getInputMIROps(const ::caffe2::OperatorDef&);
-
-  /**
-  * @brief create MIR inputs with given names and shapes
-  */
-  // void createGraphInputs(const std::vector<std::string>&, const std::vector<mir::Shape>&);
+  std::vector<mir::IODescriptor> getInputMIROps(const ::caffe2::OperatorDef&);
 
   /**
   * @brief Mark output MIR nodes
   */
-  // void setGraphOutputs();
+  void setGraphOutputs();
 
   /**
   * @brief Set MIR node names
   */
-  // void setIrNodeNames();
+  void setIrNodeNames();
 };
 
 } // namespace nnc
diff --git a/contrib/nnc/include/passes/common_frontend/op_creator_helper.h b/contrib/nnc/include/passes/common_frontend/op_creator_helper.h
new file mode 100644 (file)
index 0000000..11309b9
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FRONTEND_COMMON_OP_CREATOR_HELPER_H_
+#define FRONTEND_COMMON_OP_CREATOR_HELPER_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "core/modelIR/Shape.h"
+#include "core/modelIR/TensorVariant.h"
+
+namespace nnc {
+
+/** Converts a kernel for grouped 2D convolution into a kernel for ordinary 2D convolution.
+ *
+ * Grouped convolution splits the input and kernel channels into the selected number of groups and applies convolution within every group of channels independently.
+ * This keeps the kernel compact (channels from different groups are not merged, so no redundant zero weights are stored).
+ * The compiler does not support this yet, so this function unfolds the compact kernel into the classic "every input channel affects every output channel" form
+ * by inserting zero coefficients where needed.
+ *
+ * @param groups number of groups in grouped convolution
+ * @param folded_kernel original grouped kernel
+ * @return unfolded kernel, compatible with an ordinary conv2D operation
+ */
+std::shared_ptr<mir::TensorVariant>
+fixGroupedKernel(int groups, std::shared_ptr<mir::TensorVariant> folded_kernel);
+
+} // namespace nnc
+
+#endif // FRONTEND_COMMON_OP_CREATOR_HELPER_H_
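
To make the unfolding concrete, a hedged usage sketch (the shape values and the `folded` variable are hypothetical):

    // Sketch: with groups = 2, a folded kernel of shape [3, 3, 8, 16]
    // ([H, W, inputChannels/groups, outputChannels]) unfolds to [3, 3, 16, 16].
    // Output channels 0..7 then read only input channels 0..7, output channels
    // 8..15 only input channels 8..15; all cross-group weights are zero-filled.
    std::shared_ptr<mir::TensorVariant> unfolded = nnc::fixGroupedKernel(2, folded);
    assert(unfolded->getShape().dim(2) == 16);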
index 69bc7d8..d9673ae 100644 (file)
@@ -26,7 +26,7 @@
 #include "caffe2/proto/caffe2.pb.h"
 
 #include "caffe2_op_types.h"
-// #include "caffe2_op_creator.h"
+#include "caffe2_op_creator.h"
 
 #include "core/modelIR/Shape.h"
 #include "core/modelIR/operations/VariableOp.h"
@@ -44,13 +44,13 @@ Caffe2Importer::Caffe2Importer(std::string predictNet, std::string initNet,
                                std::vector<std::vector<int>> shapes) :
         _predictNet(std::move(predictNet)),
         _initNet(std::move(initNet)),
-        _graph(new mir::Graph())/*,
-        _opCreator(new Caffe2OpCreator(_graph))*/ {
-  for(auto& shape : shapes)
+        _graph(new mir::Graph()),
+        _opCreator(new Caffe2OpCreator(_graph)) {
+  for (auto& shape : shapes)
     _inputShapes.emplace_back(shape);
 }
 
-Caffe2Importer::~Caffe2Importer()=default;
+Caffe2Importer::~Caffe2Importer() = default;
 
 PassData Caffe2Importer::run(PassData) {
   import();
@@ -66,41 +66,240 @@ void Caffe2Importer::import() {
 
   _net.reset(new NetDef());
   if (!readProtoFromBinaryFile<::caffe2::NetDef>(_predictNet.c_str(), _net.get()))
-    throw PassException("Could not load model: " + _predictNet+ "\n");
+    throw PassException("Could not load model: " + _predictNet + "\n");
 
   std::unique_ptr<NetDef> net2;
   net2.reset(new NetDef());
   if (!readProtoFromBinaryFile<::caffe2::NetDef>(_initNet.c_str(), net2.get()))
-    throw PassException("Could not load model: " + _initNet+ "\n");
+    throw PassException("Could not load model: " + _initNet + "\n");
   _net->MergeFrom(*net2);
 
-  // collectUnsupportedOps();
+  preloadAllTensors();
 
-  // preloadAllTensors();
+  collectUnsupportedOps();
 }
 
 mir::Graph* Caffe2Importer::createIR() {
-  throw PassException("Caffe2: NYI");
-  /*
   for (auto& op : _net->op())
     createMIRNodesFromOp(op);
 
   setIrNodeNames();
   setGraphOutputs();
-  */
 
   return _graph;
 }
 
+void Caffe2Importer::collectUnsupportedOps() {
+  for (auto& op : _net->op())
+    collectUnsupportedOp(op);
+
+  if (!_problemsOpSet.empty()) {
+    std::string msg("Detected problems:\n");
+    for (const auto& problemStr : _problemsOpSet)
+      msg.append(problemStr + "\n");
+    throw PassException(msg);
+  }
+}
+
+void Caffe2Importer::collectUnsupportedOp(const OperatorDef& op) {
+  if (_operatorTypes.find(op.type()) == _operatorTypes.end()) {
+    _problemsOpSet.insert(op.type() + ": unknown layer");
+    return;
+  }
+
+  SupportedCaffe2OpType opType = _operatorTypes.at(op.type());
+  switch (opType) {
+    case SupportedCaffe2OpType::FC:
+      _opCreator->checkFC(op, _problemsOpSet);
+      break;
+    case SupportedCaffe2OpType::spatialBN:
+      _opCreator->checkSpatialBN(op, _problemsOpSet);
+      break;
+    case SupportedCaffe2OpType::add:
+    case SupportedCaffe2OpType::averagePool:
+    case SupportedCaffe2OpType::concat:
+    case SupportedCaffe2OpType::constantFill:
+    case SupportedCaffe2OpType::conv:
+    case SupportedCaffe2OpType::dropout:
+    case SupportedCaffe2OpType::givenTensorFill:
+    case SupportedCaffe2OpType::maxPool:
+    case SupportedCaffe2OpType::mul:
+    case SupportedCaffe2OpType::relu:
+    case SupportedCaffe2OpType::softmax:
+    case SupportedCaffe2OpType::sum:
+      _opCreator->commonCheck(op, _problemsOpSet);
+      break;
+    default:
+      _problemsOpSet.insert(op.type() + ": unsupported layer");
+      break;
+  }
+}
+
+void Caffe2Importer::preloadAllTensors() {
+  for (auto& op : _net->op()) {
+    // All tensor values are stored in 'GivenTensorFill' and 'ConstantFill' operators, so skip the rest.
+    // Look the type up with find(): preloading runs before collectUnsupportedOps, so an unknown
+    // op type must be reported there instead of aborting here with std::out_of_range.
+    auto typeIt = _operatorTypes.find(op.type());
+    if (typeIt == _operatorTypes.end())
+      continue;
+    auto opType = typeIt->second;
+    if ((opType == SupportedCaffe2OpType::givenTensorFill
+         || opType == SupportedCaffe2OpType::constantFill)
+        && hasArgument(op.arg(), "values")) {
+      _MIRTensors.insert(
+              std::pair<std::string, std::shared_ptr<MIRTensor>>(op.output(0), createTensor(op)));
+    }
+  }
+}
+
+void Caffe2Importer::createMIRNodesFromOp(const OperatorDef& op) {
+  std::vector<mir::IODescriptor> outputs;
+
+  // If the op input has not been seen yet, treat it as a model input
+  if (op.input_size() > 0
+      && _blobNameToIODescriptor.find(op.input(0)) == _blobNameToIODescriptor.end()) {
+
+    outputs = _opCreator->createInput(op.input(0), _inputShapes.front());
+    _blobNameToIODescriptor[op.input(0)] = outputs.at(0);
+
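+    // Consume the shape just bound to this input; remaining shapes belong to later inputs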
+    _inputShapes.erase(_inputShapes.begin());
+  }
+
+  auto inputs = getInputMIROps(op);
+
+  SupportedCaffe2OpType opType = _operatorTypes.at(op.type());
+  switch (opType) {
+    case SupportedCaffe2OpType::constantFill:
+    case SupportedCaffe2OpType::givenTensorFill:
+      return;
+    case SupportedCaffe2OpType::add:
+      outputs = _opCreator->convertAdd(inputs, op, _MIRTensors);
+      break;
+    case SupportedCaffe2OpType::averagePool:
+      outputs = _opCreator->convertAveragePool(inputs, op);
+      break;
+    case SupportedCaffe2OpType::conv:
+      outputs = _opCreator->convertConv(inputs, op, _MIRTensors);
+      break;
+    case SupportedCaffe2OpType::concat:
+      outputs = _opCreator->convertConcat(inputs, op);
+      break;
+    case SupportedCaffe2OpType::dropout:
+      outputs = _opCreator->convertDropout(inputs, op);
+      break;
+    case SupportedCaffe2OpType::FC:
+      outputs = _opCreator->convertFullyConnected(inputs, op, _MIRTensors);
+      break;
+    case SupportedCaffe2OpType::maxPool:
+      outputs = _opCreator->convertMaxPool(inputs, op);
+      break;
+    case SupportedCaffe2OpType::mul:
+      outputs = _opCreator->convertMul(inputs, op, _MIRTensors);
+      break;
+    case SupportedCaffe2OpType::relu:
+      outputs = _opCreator->convertRelu(inputs);
+      break;
+    case SupportedCaffe2OpType::softmax:
+      outputs = _opCreator->convertSoftmax(inputs, op);
+      break;
+    case SupportedCaffe2OpType::spatialBN:
+      outputs = _opCreator->convertSpatialBN(inputs, op, _MIRTensors);
+      break;
+    case SupportedCaffe2OpType::sum:
+      outputs = _opCreator->convertSum(inputs);
+      break;
+    default:
+      assert(false && "All unsupported types should have been found before this pass.");
+  }
+
+  for (int i = 0; i < outputs.size(); ++i) {
+    // A caffe2 input blob name can coincide with an output blob name; the next line then overwrites
+    // the '_blobNameToIODescriptor' element, but this has not been a problem in any network seen so far
+    _blobNameToIODescriptor[op.output(i)] = outputs.at(i);
+  }
+
+  _lastMIROp = outputs.at(0).op;
+}
+
+std::shared_ptr<IrTensor> Caffe2Importer::createTensor(const OperatorDef& op) {
+  assert(hasArgument(op.arg(), "shape") && hasArgument(op.arg(), "values"));
+
+  auto shape = findArgumentByName(op.arg(), "shape");
+  auto values = findArgumentByName(op.arg(), "values");
+
+  // Create untyped tensor. Note, tensor contents will be *copied* here.
+  auto type = mir::DTYPE::FLOAT32;
+  size_t elementSize = sizeof(float);
+  size_t bufferSize = values.floats().size() * elementSize;
+  const char* srcData = reinterpret_cast<const char*>(values.floats().data());
+  std::shared_ptr<char> tensorBufferCopy(new char[bufferSize],
+                                         std::default_delete<char[]>());
+  char* dstData = tensorBufferCopy.get();
+  memcpy(dstData, srcData, bufferSize);
+
+  Shape tensor_shape = ShapeHelper::createShape(
+          shape.ints(), static_cast<size_t>(shape.ints().size()));
+
+  auto tensor = std::make_shared<IrTensor>(tensor_shape, tensorBufferCopy, type, elementSize);
+
+  return tensor;
+}
+
+std::vector<mir::IODescriptor> Caffe2Importer::getInputMIROps(const OperatorDef& op) {
+  // caffe2 operation inputs are not the same as MIR inputs (e.g. in caffe2 the conv kernel and bias are inputs too),
+  // so pick only the caffe2 inputs that are 'real' data inputs
+  std::vector<mir::IODescriptor> inputs;
+  SupportedCaffe2OpType opType = _operatorTypes.at(op.type());
+  switch (opType) {
+    case SupportedCaffe2OpType::givenTensorFill:
+    case SupportedCaffe2OpType::constantFill:
+      break;
+    case SupportedCaffe2OpType::add:
+    case SupportedCaffe2OpType::averagePool:
+    case SupportedCaffe2OpType::conv:
+    case SupportedCaffe2OpType::dropout:
+    case SupportedCaffe2OpType::FC:
+    case SupportedCaffe2OpType::maxPool:
+    case SupportedCaffe2OpType::mul:
+    case SupportedCaffe2OpType::relu:
+    case SupportedCaffe2OpType::softmax:
+    case SupportedCaffe2OpType::spatialBN:
+      inputs.push_back(_blobNameToIODescriptor[op.input(0)]);
+      break;
+    case SupportedCaffe2OpType::sum:
+    case SupportedCaffe2OpType::concat:
+      for (auto& i : op.input())
+        inputs.push_back(_blobNameToIODescriptor[i]);
+      break;
+    default:
+      assert(false && "All unsupported types should have been found before this pass.");
+  }
+
+  return inputs;
+}
+
+void Caffe2Importer::setGraphOutputs() {
+  // For now, we assume that:
+  //   - there is exactly one output;
+  //   - the output is from the last layer.
+  _graph->markOutput(_lastMIROp);
+}
+
+void Caffe2Importer::setIrNodeNames() {
+  for (auto& item : _blobNameToIODescriptor)
+    item.second.op->setName(item.first);
+}
+
 const std::map<std::string, SupportedCaffe2OpType> Caffe2Importer::_operatorTypes = {
+        {"Add",             SupportedCaffe2OpType::add},
         {"AveragePool",     SupportedCaffe2OpType::averagePool},
         {"Conv",            SupportedCaffe2OpType::conv},
+        {"Concat",          SupportedCaffe2OpType::concat},
+        {"ConstantFill",    SupportedCaffe2OpType::constantFill},
         {"Dropout",         SupportedCaffe2OpType::dropout},
         {"FC",              SupportedCaffe2OpType::FC},
         {"GivenTensorFill", SupportedCaffe2OpType::givenTensorFill},
         {"MaxPool",         SupportedCaffe2OpType::maxPool},
+        {"Mul",             SupportedCaffe2OpType::mul},
         {"Relu",            SupportedCaffe2OpType::relu},
         {"Softmax",         SupportedCaffe2OpType::softmax},
+        {"SpatialBN",       SupportedCaffe2OpType::spatialBN},
         {"Sum",             SupportedCaffe2OpType::sum}
 };
 
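For orientation, a minimal driver sketch of the importer API implemented above (the file paths and the PassData argument are placeholders; in the real compiler the pass manager invokes the importer):

    // Hypothetical usage sketch, not the actual nnc entry point.
    nnc::Caffe2Importer importer("predict_net.pb", "init_net.pb",
                                 {{1, 3, 224, 224}});  // one NCHW input shape
    importer.run(nnc::PassData(nullptr));  // load + merge NetDefs, preload tensors, check ops
    mir::Graph* graph = importer.createIR();  // build the MIR graph
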
diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
new file mode 100644 (file)
index 0000000..57b4652
--- /dev/null
@@ -0,0 +1,328 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/CappedReluOp.h"
+#include "core/modelIR/operations/ConcatOp.h"
+#include "core/modelIR/operations/Conv2DOp.h"
+#include "core/modelIR/operations/DepthwiseConv2DOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ScaleOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/VariableOp.h"
+
+#include "core/modelIR/Index.h"
+#include "core/modelIR/Shape.h"
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/Tensor.h"
+#include "core/modelIR/TensorUtil.h"
+
+#include "passes/common_frontend/op_creator_helper.h"
+#include "passes/common_frontend/shape_helper.h"
+#include "pass/PassException.h"
+#include "caffe2_op_creator.h"
+#include "caffe2_proto_helper.h"
+
+#include <cmath>
+#include <set>
+#include <vector>
+#include "option/Options.h"
+
+
+namespace nnc {
+
+using namespace ::caffe2;
+using namespace mir;
+using nnc::mir::transposeTensor;
+
+//
+// Helper functions
+//
+
+mir::IODescriptor Caffe2OpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
+  if (cli::debugTranspose) {
+    // NCHW -> NHWC
+    auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
+    return transpose->getOutput(0);
+  } else {
+    return arg;
+  }
+}
+
+mir::IODescriptor Caffe2OpCreator::convertMIRToCaffe(const mir::IODescriptor& arg) {
+  if (cli::debugTranspose) {
+    // NHWC -> NCHW
+    auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
+    return transpose->getOutput(0);
+  } else {
+    return arg;
+  }
+}
+
+//
+// Check functions
+//
+
+void Caffe2OpCreator::commonCheck(const ::caffe2::OperatorDef& op,
+                                  std::set<std::string>& problemsOpSet) {
+  if (getSingleArgument(op, "order", "NCHW") != "NCHW")
+    problemsOpSet.insert("Only 'NCHW' order is supported");
+}
+
+void Caffe2OpCreator::checkFC(const ::caffe2::OperatorDef& op,
+                              std::set<std::string>& problemsOpSet) {
+  commonCheck(op, problemsOpSet);
+  for (auto& s : {"axis", "axis_w", "float16_compute"})
+    if (hasArgument(op.arg(), s))
+      problemsOpSet.insert(std::string("FC: only default '") + s + "' value is supported");
+}
+
+void Caffe2OpCreator::checkSpatialBN(const ::caffe2::OperatorDef& op,
+                                     std::set<std::string>& problemsOpSet) {
+  commonCheck(op, problemsOpSet);
+  if (op.input_size() != 5)
+    problemsOpSet.insert(
+            "SpatialBN must have exactly 5 inputs ('sums' and 'sumsq' are not supported yet)");
+}
+
+//
+// Convert functions
+//
+
+std::vector<mir::IODescriptor>
+Caffe2OpCreator::convertAdd(const std::vector<mir::IODescriptor>& inputs,
+                            const ::caffe2::OperatorDef& op,
+                            const MIRTensors& mirTensors) {
+  // TODO: not tested
+  throw PassException("Caffe2 Add op not tested yet");
+  auto& addend = mirTensors.at(op.input(1));
+  auto add = createOp<ops::BiasAddOp>(inputs[0], *addend);
+  return {add->getOutput(0)};
+}
+
+std::vector<IODescriptor>
+Caffe2OpCreator::convertAveragePool(const std::vector<IODescriptor>& inputs,
+                                    const OperatorDef& op) {
+  // TODO: implement custom paddings
+  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
+                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
+  if (has_custom_pad)
+    throw PassException("Custom one-side padding not supported yet");
+
+  int kernel_size = static_cast<int>(findArgumentByName(op.arg(), "kernel").i());
+  Shape window_shape = Shape({kernel_size, kernel_size});
+
+  int stride = static_cast<int>(findArgumentByName(op.arg(), "stride").i());
+  Shape strides = Shape({stride, stride});
+
+  ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::AVG;
+  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::ZEROFILLED;
+
+  int pad = getSingleArgument(op, "pad", 0);
+  std::vector<int32_t> padding{pad, pad};
+
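+  // NOTE: unlike convertMaxPool below, this uses inputs[0] as-is, without the
+  // convertCaffeToMIR/convertMIRToCaffe layout wrappers; possibly an oversight.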
+  auto pooling = createOp<ops::PoolOp>(inputs[0], pool_type, window_shape, strides, padding,
+                                       padding, border_type, ops::PoolOp::RoundMode::ceil);
+
+  return {pooling->getOutput(0)};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertConv(const std::vector<IODescriptor>& inputs,
+                                                       const ::caffe2::OperatorDef& op,
+                                                       const MIRTensors& mirTensors) {
+  int stride = getSingleArgument(op, "stride", 1);
+  Shape stride_shape = Shape({stride, stride});
+
+  int pad = getSingleArgument(op, "pad", 0);
+  std::vector<int32_t> padding{pad, pad};
+
+  auto kernel_tensor = transposeTensor<2, 3, 1, 0>(mirTensors.at(op.input(1)));
+  auto in_group_size = kernel_tensor->getShape().dim(2);
+  auto out_channels = kernel_tensor->getShape().dim(3);
+  int num_groups = getSingleArgument(op, "group", 1);
+  bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
+
+  mir::Operation* conv2d;
+  if (is_depthwise) {
+    // This is depthwise convolution
+    // TODO handle properly kernel with layer multiplier
+    std::shared_ptr<IrTensor> transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
+    conv2d = createOp<ops::DepthwiseConv2DOp>(convertCaffeToMIR(inputs[0]), *transposed_tensor,
+                                              stride_shape, padding, padding);
+  } else {
+    // First convert the grouped convolution kernel into an equivalent ordinary kernel
+    if (num_groups != 1)
+      kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
+
+    conv2d = createOp<ops::Conv2DOp>(convertCaffeToMIR(inputs[0]), *kernel_tensor,
+                                     stride_shape, padding, padding);
+  }
+
+  if (op.input_size() > 2) {  // Bias is optional
+    auto bias_add = createOp<ops::BiasAddOp>(conv2d->getOutput(0), *mirTensors.at(op.input(2)));
+    return {convertMIRToCaffe(bias_add->getOutput(0))};
+  }
+  return {convertMIRToCaffe(conv2d->getOutput(0))};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertConcat(const std::vector<IODescriptor>& inputs,
+                                                         const ::caffe2::OperatorDef& op) {
+  // TODO: not tested
+  throw PassException("Caffe2 Concat op not tested yet");
+  int axis = getSingleArgument(op, "axis", -1);
+  auto result = createOp<ops::ConcatOp>(inputs, axis);
+  return {result->getOutput(0)};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertDropout(const std::vector<IODescriptor>& inputs,
+                                                          const ::caffe2::OperatorDef& op) {
+  // TODO: not tested
+  throw PassException("Caffe2 Dropout op not tested yet");
+  int is_test = getSingleArgument(op, "is_test", 0);
+  if (is_test)
+    return {inputs[0]};
+
+  float dropout_ratio = getSingleArgument(op, "ratio", 0.5f);
+  auto dropout = createOp<ops::DropoutOp>(inputs[0], dropout_ratio);
+  return {dropout->getOutput(0)};
+}
+
+// TODO: describe caffe2 FC interface
+std::vector<IODescriptor>
+Caffe2OpCreator::convertFullyConnected(const std::vector<IODescriptor>& inputs,
+                                       const ::caffe2::OperatorDef& op,
+                                       const MIRTensors& mirTensors) {
+  auto weightsTensor = mirTensors.at(op.input(1));
+  weightsTensor = transposeTensor<1, 0>(weightsTensor);
+  int32_t fc_input_size = weightsTensor->getShape().dim(0);
+
+  // Add a Reshape operation to make sure the input to the FC operation has shape [1, fc_input_size].
+  // It is needed because Caffe2 FC layer takes NCHW input and flattens the CHW part.
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], Shape({1, fc_input_size}));
+
+  auto fully_connected = createOp<ops::FullyConnectedOp>(reshape->getOutput(0), *weightsTensor);
+
+  auto bias = createOp<ops::BiasAddOp>(fully_connected->getOutput(0), *mirTensors.at(op.input(2)));
+  return {bias->getOutput(0)};
+}
+
+std::vector<IODescriptor>
+Caffe2OpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) {
+  // TODO For now we only support convolutional networks with one element per batch.
+  assert(input_shape.rank() == 4 && input_shape.dim(0) == 1);
+
+  // TODO Do not transpose data on input and remove transpose.
+  auto transposed_shape = mir::Shape{input_shape.dim(0), input_shape.dim(2),
+                                     input_shape.dim(3), input_shape.dim(1)};
+  auto variable = _graph->create<ops::VariableOp>(input_name, transposed_shape);
+  return {convertMIRToCaffe(variable->getOutput(0))};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertMaxPool(const std::vector<IODescriptor>& inputs,
+                                                          const OperatorDef& op) {
+  // TODO: implement custom paddings
+  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
+                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
+  if (has_custom_pad)
+    throw PassException("Custom one-side padding not supported yet");
+
+  int window_length = static_cast<int>(findArgumentByName(op.arg(), "kernel").i());
+  Shape window_shape = Shape({window_length, window_length});
+
+  int stride = static_cast<int>(findArgumentByName(op.arg(), "stride").i());
+  Shape strides = Shape({stride, stride});
+
+  ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::MAX;
+  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::EMPTY;
+
+  int pad = getSingleArgument(op, "pad", 0);
+  std::vector<int32_t> padding{pad, pad};
+
+  auto pooling = createOp<ops::PoolOp>(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+                                       strides, padding, padding, border_type,
+                                       ops::PoolOp::RoundMode::ceil);
+
+  return {convertMIRToCaffe(pooling->getOutput(0))};
+}
+
+std::vector<mir::IODescriptor>
+Caffe2OpCreator::convertMul(const std::vector<mir::IODescriptor>& inputs,
+                            const ::caffe2::OperatorDef& op,
+                            const MIRTensors& mirTensors) {
+  // TODO: not tested
+  throw PassException("Caffe Mul op not tested yet");
+  auto& multiplier = mirTensors.at(op.input(1));
+  auto mul = createOp<ops::ScaleOp>(inputs[0], *multiplier);
+  return {mul->getOutput(0)};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertRelu(const std::vector<IODescriptor>& inputs) {
+  auto relu = createOp<ops::ReluOp>(inputs[0]);
+  return {relu->getOutput(0)};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertSoftmax(const std::vector<IODescriptor>& inputs,
+                                                          const ::caffe2::OperatorDef& op) {
+  int axis = getSingleArgument(op, "axis", 1);
+  auto softmax = createOp<ops::SoftmaxOp>(inputs[0], axis);
+  return {softmax->getOutput(0)};
+}
+
+std::vector<mir::IODescriptor>
+Caffe2OpCreator::convertSpatialBN(const std::vector<mir::IODescriptor>& inputs,
+                                  const ::caffe2::OperatorDef& op,
+                                  const MIRTensors& mirTensors) {
+  // TODO: not tested
+  throw PassException("Caffe2 SpatialBN op not tested yet");
+  // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
+
+  auto& scale = mirTensors.at(op.input(1));
+  auto& bias = mirTensors.at(op.input(2));
+  auto& mean = mirTensors.at(op.input(3));
+  auto& var = mirTensors.at(op.input(4));
+  float eps = getSingleArgument(op, "epsilon", 1e-5f);
+
+  // res1 = X - mean
+  Tensor<float> bias_data(*mean);
+  for (Index idx: ShapeRange(bias_data.getShape()))
+    bias_data.at(idx) *= -1;
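+  // NOTE (assumption): mir::Tensor<float> acts as a mutating view over the underlying
+  // TensorVariant, so the loop above negates the data of *mean in place before it is used.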
+  auto bias_add_1 = createOp<ops::BiasAddOp>(convertCaffeToMIR(inputs[0]), *mean);
+
+  // res2 = res1 * scale / sqrt(var + epsilon)
+  Tensor<float> multiplier(*scale);
+  for (Index idx: ShapeRange(scale->getShape()))
+    multiplier.at(idx) *= 1.0f / std::sqrt(*(float*) var->at(idx) + eps);
+  auto scale_op = createOp<ops::ScaleOp>(bias_add_1->getOutput(0), *scale);
+
+  // overall_res = res2 + bias
+  auto bias_add_2 = createOp<ops::BiasAddOp>(scale_op->getOutput(0), *bias);
+
+  return {convertMIRToCaffe(bias_add_2->getOutput(0))};
+}
+
+std::vector<IODescriptor> Caffe2OpCreator::convertSum(const std::vector<IODescriptor>& inputs) {
+  auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::add);
+  return {op->getOutput(0)};
+}
+
+} // namespace nnc
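
The commit message names custom paddings as the blocker for `inception`. A hedged sketch of what reading them could look like, reusing the argument names the pooling checks above already test for (whether ops::PoolOp accepts independent before/after padding vectors in this form is an assumption):

    // Sketch only: map one-sided Caffe2 pads onto separate before/after vectors,
    // falling back to the symmetric "pad" argument where a side is not given.
    int pad = getSingleArgument(op, "pad", 0);
    std::vector<int32_t> pad_before{getSingleArgument(op, "pad_t", pad),
                                    getSingleArgument(op, "pad_l", pad)};
    std::vector<int32_t> pad_after{getSingleArgument(op, "pad_b", pad),
                                   getSingleArgument(op, "pad_r", pad)};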
diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h b/contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h
new file mode 100644 (file)
index 0000000..09b2d0e
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_CAFFE2_OP_CREATOR_H
+#define NNCC_CAFFE2_OP_CREATOR_H
+
+#include <set>
+#include <map>
+#include <vector>
+#include <memory>
+
+#include "core/modelIR/Graph.h"
+#include "core/modelIR/Operation.h"
+#include "core/modelIR/TensorVariant.h"
+#include "core/modelIR/operations/CommonProps.h"
+#include "core/modelIR/Shape.h"
+
+#include "caffe2/proto/caffe2.pb.h"
+
+namespace nnc {
+
+using nnc::mir::Graph;
+using nnc::mir::Operation;
+using IrTensor = nnc::mir::TensorVariant;
+using nnc::mir::Shape;
+using MIRTensors = const std::map<std::string, std::shared_ptr<mir::TensorVariant>>;
+
+class Caffe2OpCreator {
+public:
+  explicit Caffe2OpCreator(Graph* g) : _graph(g) {}
+
+  void commonCheck(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
+  void checkFC(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
+  void checkSpatialBN(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
+  std::vector<mir::IODescriptor> convertAdd(const std::vector<mir::IODescriptor>&,
+                                            const ::caffe2::OperatorDef&, const MIRTensors&);
+
+  std::vector<mir::IODescriptor> convertAveragePool(const std::vector<mir::IODescriptor>&,
+                                                    const ::caffe2::OperatorDef&);
+
+  std::vector<mir::IODescriptor> convertConv(const std::vector<mir::IODescriptor>&,
+                                             const ::caffe2::OperatorDef&, const MIRTensors&);
+
+  std::vector<mir::IODescriptor> convertConcat(const std::vector<mir::IODescriptor>&,
+                                               const ::caffe2::OperatorDef&);
+
+  std::vector<mir::IODescriptor> convertDropout(const std::vector<mir::IODescriptor>&,
+                                                const ::caffe2::OperatorDef&);
+
+  std::vector<mir::IODescriptor> convertFullyConnected(const std::vector<mir::IODescriptor>&,
+                                                       const ::caffe2::OperatorDef&,
+                                                       const MIRTensors&);
+
+  std::vector<mir::IODescriptor> createInput(const std::string&, const mir::Shape&);
+
+  std::vector<mir::IODescriptor> convertMaxPool(const std::vector<mir::IODescriptor>&,
+                                                const ::caffe2::OperatorDef&);
+
+  std::vector<mir::IODescriptor> convertMul(const std::vector<mir::IODescriptor>&,
+                                            const ::caffe2::OperatorDef&, const MIRTensors&);
+
+  std::vector<mir::IODescriptor> convertRelu(const std::vector<mir::IODescriptor>&);
+
+  std::vector<mir::IODescriptor> convertSoftmax(const std::vector<mir::IODescriptor>&,
+                                                const ::caffe2::OperatorDef&);
+
+  std::vector<mir::IODescriptor> convertSpatialBN(const std::vector<mir::IODescriptor>&,
+                                                  const ::caffe2::OperatorDef&, const MIRTensors&);
+
+  std::vector<mir::IODescriptor> convertSum(const std::vector<mir::IODescriptor>&);
+
+private:
+  Graph* _graph = nullptr;
+
+  mir::IODescriptor convertCaffeToMIR(const mir::IODescriptor& arg);
+
+  mir::IODescriptor convertMIRToCaffe(const mir::IODescriptor& arg);
+
+  template <typename OpType, typename ...Types>
+  mir::Operation* createOp(Types&& ... args);
+};
+
+template <typename OpType, typename ...Types>
+mir::Operation* Caffe2OpCreator::createOp(Types&& ... args) {
+  // TODO: set operation names
+  return _graph->create<OpType>("", std::forward<Types>(args)...);
+}
+
+} // namespace nnc
+
+#endif //NNCC_CAFFE2_OP_CREATOR_H
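
Note on the createOp helper above: it forwards to Graph::create with an empty name; names are assigned afterwards by Caffe2Importer::setIrNodeNames. For example:

    // createOp<ops::SoftmaxOp>(inputs[0], axis)
    // expands to
    // _graph->create<ops::SoftmaxOp>("", inputs[0], axis);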
index 6aa56e7..8ac7260 100644 (file)
 namespace nnc {
 
 enum class SupportedCaffe2OpType : uint8_t {
+  add,
   averagePool,
+  concat,
   conv,
+  constantFill,
   dropout,
   FC,
   givenTensorFill,
   maxPool,
+  mul,
   relu,
   softmax,
+  spatialBN,
   sum
 };
 
index b13436c..8e79e89 100644 (file)
@@ -36,4 +36,25 @@ const bool hasArgument(RepArgument args, std::string name) {
   return false;
 }
 
+int getSingleArgument(const ::caffe2::OperatorDef& op, const std::string& argument_name,
+                      const int default_value) {
+  if (hasArgument(op.arg(), argument_name))
+    return static_cast<int>(findArgumentByName(op.arg(), argument_name).i());
+  return default_value;
+}
+
+float getSingleArgument(const ::caffe2::OperatorDef& op, const std::string& argument_name,
+                        const float default_value) {
+  if (hasArgument(op.arg(), argument_name))
+    return findArgumentByName(op.arg(), argument_name).f();
+  return default_value;
+}
+
+std::string getSingleArgument(const ::caffe2::OperatorDef& op, const std::string& argument_name,
+                              const std::string& default_value) {
+  if (hasArgument(op.arg(), argument_name))
+    return findArgumentByName(op.arg(), argument_name).s();
+  return default_value;
+}
+
 } // namespace nnc
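
Usage note for the overloads above: the overload is selected by the type of the default value, so call sites must pass an exact-type default (a sketch; `op` is any ::caffe2::OperatorDef):

    int stride = getSingleArgument(op, "stride", 1);             // int overload
    float ratio = getSingleArgument(op, "ratio", 0.5f);          // float overload; a plain 0.5 would be ambiguous
    std::string order = getSingleArgument(op, "order", "NCHW");  // std::string overload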
index 750f396..5b62e23 100644 (file)
@@ -22,8 +22,13 @@ namespace nnc {
 using RepArgument = const ::google::protobuf::RepeatedPtrField<::caffe2::Argument>&;
 
 const ::caffe2::Argument& findArgumentByName(RepArgument args, std::string name);
+
 const bool hasArgument(RepArgument args, std::string name);
 
+int getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const int);
+float getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const float);
+std::string getSingleArgument(const ::caffe2::OperatorDef&, const std::string&, const std::string&);
+
 } // namespace nnc
 
 #endif // NNCC_CAFFE2_PROTO_HELPER_H
index c0cb96a..caff041 100644 (file)
@@ -39,6 +39,7 @@
 #include "core/modelIR/TensorUtil.h"
 
 #include "passes/common_frontend/shape_helper.h"
+#include "passes/common_frontend/op_creator_helper.h"
 #include "pass/PassException.h"
 #include "caffe_op_creator.h"
 
@@ -52,65 +53,6 @@ namespace nnc {
 using namespace mir;
 using namespace ::caffe;
 
-/** Convert kernel for grouped 2d convolution in kernel for ordinary 2d convolution
- *
- * Grouped convolution breaks input and kernel channels into selected number of groups and applies convolution in every group of channels independently.
- * This technique allows to save kernel size(channels from different groups are not merged, no need to store redundant 0 weights).
- * This is not supported by compiler for now, so this function unfolds compact kernel into classic flavored "every input layer affects every output layer",
- * by inserting zero coefficients where needed
- *
- * @param groups number of groups in grouped convolution
- * @param foldedKernel original grouped kernel
- * @return unfolded kernel, compatible with ordinary conv2D operation
- */
-static std::shared_ptr<IrTensor>
-fixGroupedKernel(int groups, std::shared_ptr<IrTensor> folded_kernel) {
-  const int kernel_in_chan_num = 2;
-  const int kernel_out_chan_num = 3;
-
-  const Shape& kernel_shape = folded_kernel->getShape();
-  auto kernel_in_channels = kernel_shape.dim(kernel_in_chan_num);
-  auto kernel_out_channels = kernel_shape.dim(kernel_out_chan_num);
-  auto in_channels = kernel_in_channels * groups;
-
-  // Original kernel has shape [H, W, inputChannels/groups, outputChannels]
-  // here creates unfolded kernel with shape [H, W, inputChannels, outputChannels]
-  Shape unfold_kernel_shape(kernel_shape);
-  unfold_kernel_shape.dim(kernel_in_chan_num) = in_channels;
-  auto buffer_size = unfold_kernel_shape.numElements() * folded_kernel->getElementSize();
-  std::shared_ptr<char> buffer(new char[buffer_size], std::default_delete<char[]>());
-  size_t data_size = folded_kernel->getElementSize();
-  std::shared_ptr<IrTensor> unfold_kernel =
-          std::make_shared<IrTensor>(unfold_kernel_shape, buffer, folded_kernel->getDataType(),
-                                     data_size);
-
-  int in_group_size = kernel_in_channels;
-  int out_group_size = kernel_out_channels / groups;
-  assert(kernel_out_channels % groups == 0);
-
-  // Iterate over "unfolded" kernel Shape and insert appropriate values into result kernel
-  for (const mir::Index& idx: mir::ShapeRange(unfold_kernel_shape)) {
-    auto in_group_no = idx.at(kernel_in_chan_num) / in_group_size;
-    auto out_group_no = idx.at(kernel_out_chan_num) / out_group_size;
-    // check that input channel group fits output channel group
-    if (in_group_no == out_group_no) {
-      // compute index in original kernel that corresponds output index
-      mir::Index folded_idx(idx);
-      folded_idx.at(kernel_in_chan_num) %= in_group_size;
-
-      std::copy(folded_kernel->at(folded_idx), folded_kernel->at(folded_idx) + data_size,
-                unfold_kernel->at(idx));
-    } else {
-      // fill element of output kernel with zero element
-      assert(folded_kernel->getDataType() == DTYPE::FLOAT32 &&
-             "unsupported data type, add appropriate zero element creation");
-      float* elem = reinterpret_cast<float*>(unfold_kernel->at(idx));
-      *elem = 0.0f;
-    }
-  }
-  return unfold_kernel;
-}
-
 mir::IODescriptor CaffeOpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
   if (cli::debugTranspose) {
     // NCHW -> NHWC
index 84df465..cb3cc2e 100644 (file)
@@ -2,8 +2,7 @@
 # Common for every importer code library #
 ##########################################
 
-set(COMMON_SOURCES
-        model_allocation.cpp)
+set(COMMON_SOURCES model_allocation.cpp op_creator_helper.cpp)
 
 add_library(nn_import_common STATIC ${COMMON_SOURCES})
 set_target_properties(nn_import_common PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/contrib/nnc/passes/common_frontend/op_creator_helper.cpp b/contrib/nnc/passes/common_frontend/op_creator_helper.cpp
new file mode 100644 (file)
index 0000000..7bc0315
--- /dev/null
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/common_frontend/op_creator_helper.h"
+
+#include "core/modelIR/Shape.h"
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/TensorVariant.h"
+
+namespace nnc {
+
+using namespace mir;
+
+std::shared_ptr<TensorVariant>
+fixGroupedKernel(int groups, std::shared_ptr<TensorVariant> folded_kernel) {
+  const int kernel_in_chan_num = 2;
+  const int kernel_out_chan_num = 3;
+
+  const Shape& kernel_shape = folded_kernel->getShape();
+  auto kernel_in_channels = kernel_shape.dim(kernel_in_chan_num);
+  auto kernel_out_channels = kernel_shape.dim(kernel_out_chan_num);
+  auto in_channels = kernel_in_channels * groups;
+
+  // Original kernel has shape [H, W, inputChannels/groups, outputChannels]
+  // here creates unfolded kernel with shape [H, W, inputChannels, outputChannels]
+  Shape unfold_kernel_shape(kernel_shape);
+  unfold_kernel_shape.dim(kernel_in_chan_num) = in_channels;
+  auto buffer_size = unfold_kernel_shape.numElements() * folded_kernel->getElementSize();
+  std::shared_ptr<char> buffer(new char[buffer_size], std::default_delete<char[]>());
+  size_t data_size = folded_kernel->getElementSize();
+  std::shared_ptr<TensorVariant> unfold_kernel =
+          std::make_shared<TensorVariant>(unfold_kernel_shape, buffer, folded_kernel->getDataType(),
+                                          data_size);
+
+  int in_group_size = kernel_in_channels;
+  int out_group_size = kernel_out_channels / groups;
+  assert(kernel_out_channels % groups == 0);
+
+  // Iterate over "unfolded" kernel Shape and insert appropriate values into result kernel
+  for (const mir::Index& idx: mir::ShapeRange(unfold_kernel_shape)) {
+    auto in_group_no = idx.at(kernel_in_chan_num) / in_group_size;
+    auto out_group_no = idx.at(kernel_out_chan_num) / out_group_size;
+    // check that input channel group fits output channel group
+    if (in_group_no == out_group_no) {
+      // compute index in original kernel that corresponds output index
+      mir::Index folded_idx(idx);
+      folded_idx.at(kernel_in_chan_num) %= in_group_size;
+
+      std::copy(folded_kernel->at(folded_idx), folded_kernel->at(folded_idx) + data_size,
+                unfold_kernel->at(idx));
+    } else {
+      // fill element of output kernel with zero element
+      assert(folded_kernel->getDataType() == DTYPE::FLOAT32 &&
+             "unsupported data type, add appropriate zero element creation");
+      float* elem = reinterpret_cast<float*>(unfold_kernel->at(idx));
+      *elem = 0.0f;
+    }
+  }
+  return unfold_kernel;
+}
+
+}  // namespace nnc
\ No newline at end of file