[nnc] Use Transpose operation in Caffe importer (#2468)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Wed, 5 Dec 2018 12:47:37 +0000 (15:47 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 5 Dec 2018 12:47:37 +0000 (15:47 +0300)
Use Transpose operation in Caffe importer to switch between Caffe NCHW and ModelIR NHWC formats.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/driver/Options.cpp
contrib/nnc/include/core/modelIR/operations/TransposeOp.h
contrib/nnc/include/option/Options.h
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.h

index ba6ea9e..317dc48 100644 (file)
@@ -162,6 +162,13 @@ Option<std::string> interInputData(optname("--input-model-data"),
                                    optional(true),
                                    optvalues(""),
                                    checkInFile);
+/**
+ * Miscellaneous options.
+ */
+Option<bool> debugTranspose(optname("--debug-transpose"),
+                           overview("insert transpose operations for debugging purposes"),
+                           false,
+                           optional(true));
 
 } // namespace cli
 } // namespace nnc
index 31f8d58..7f11298 100644 (file)
@@ -24,6 +24,11 @@ namespace nnc {
 namespace mir {
 namespace ops {
 
+/**
+ * @brief Tensor transpose operation.
+ *
+ * Rearranges axes of input tensor.
+ */
 class TransposeOp : public Operation {
 public:
   TransposeOp(const IODescriptor& arg, const std::vector<std::size_t>& axis_order);
index 9a605ff..0ce2c4b 100644 (file)
@@ -57,6 +57,12 @@ extern Option<std::string> artifactName;  // name of artifact
  * Options for interpreter
  */
 extern Option<std::string> interInputData;  // input data for model
+
+/**
+ * Miscellaneous options.
+ */
+extern Option<bool> debugTranspose;
+
 } // namespace cli
 } // namespace nnc
 
index 295e408..1723bf8 100644 (file)
@@ -156,6 +156,7 @@ void CaffeImporter::collectUnsupportedOp(const LayerParameter& lp) {
   std::vector<std::shared_ptr<IrTensor>> params;
 
   switch (op_type) {
+    case CaffeOpType::concat:
     case CaffeOpType::input:
     case CaffeOpType::softmax:
     case CaffeOpType::scale:
@@ -166,9 +167,6 @@ void CaffeImporter::collectUnsupportedOp(const LayerParameter& lp) {
     case CaffeOpType::tanh:
       // No checks
       break;
-    case CaffeOpType::concat:
-      _opCreator->checkConcat(lp.concat_param(), _problemsOpSet);
-      break;
     case CaffeOpType::deconvolution:
     case CaffeOpType::convolution:
       _opCreator->checkConvolution(lp.convolution_param(), _problemsOpSet);
index d66563a..feaef45 100644 (file)
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/BatchNormOp.h"
 #include "core/modelIR/operations/DropoutOp.h"
-#include <core/modelIR/operations/ElementwiseOp.h>
-#include <core/modelIR/operations/Deconv2DOp.h>
-#include <core/modelIR/operations/TanhOp.h>
-#include <core/modelIR/operations/EluOp.h>
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/Deconv2DOp.h"
+#include "core/modelIR/operations/TanhOp.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/EluOp.h"
 
 #include "core/modelIR/Index.h"
 #include "core/modelIR/ShapeRange.h"
@@ -44,6 +45,7 @@
 #include <set>
 #include <cmath>
 #include <iostream>
+#include "option/Options.h"
 
 namespace nnc {
 
@@ -129,33 +131,6 @@ static ops::PoolOp::PoolingType getPoolingType(const PoolingParameter& opts) {
                         PoolingParameter::PoolMethod_Name(opts.pool()));
 }
 
-/**
- * @brief Determines correct value for Caffe Softmax/Concat axis parameter.
- * @todo Change cout to a log library call.
- * @todo Decide how to process axis in general.
- */
-template <typename OptsType>
-static int getAxisValue(const OptsType& opts) {
-  // -1 represents last one dimension
-  int axis = -1;
-  if (opts.has_axis()) {
-    axis = opts.axis();
-    if (axis == 0)
-      std::cout << "WARNING: axis parameter equals 0. It is normal,"
-                   "but implies that the model might not have a batch dimension,"
-                   "so make sure import works correctly." << std::endl;
-    else if (axis != 1 && axis != -1)
-      throw PassException("Softmax/Concat layer axis param is not 1 or -1, which implies"
-                          "unsupported NN architecture.");
-  }
-
-  // axis 1 represents channels in caffe, in Model ir it is second dimension for now
-  if (axis == 1)
-    return 3;
-
-  return axis;
-}
-
 /** Convert kernel for grouped 2d convolution in kernel for ordinary 2d convolution
  *
  * Grouped convolution breaks input and kernel channels into selected number of groups and applies convolution in every group of channels independently.
@@ -215,6 +190,26 @@ fixGroupedKernel(int groups, std::shared_ptr<IrTensor> folded_kernel) {
   return unfold_kernel;
 }
 
+mir::IODescriptor CaffeOpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
+  if (cli::debugTranspose) {
+    // NCHW -> NHWC
+    auto transpose = createOp<ops::TransposeOp>("", arg, std::vector<std::size_t>{0, 2, 3, 1});
+    return transpose->getOutput(0);
+  } else {
+    return arg;
+  }
+}
+
+mir::IODescriptor CaffeOpCreator::convertMIRToCaffe(const mir::IODescriptor& arg) {
+  if (cli::debugTranspose) {
+    // NHWC -> NCHW
+    auto transpose = createOp<ops::TransposeOp>("", arg, std::vector<std::size_t>{0, 3, 1, 2});
+    return transpose->getOutput(0);
+  } else {
+    return arg;
+  }
+}
+
 std::vector<mir::IODescriptor>
 CaffeOpCreator::convertInput(const LayerParameter& layer) {
   const auto& params = layer.input_param();
@@ -229,18 +224,13 @@ CaffeOpCreator::convertInput(const LayerParameter& layer) {
     const auto& blob_shape = params.shape(num_shapes == 1 ? 0 : i);
     Shape shape = ShapeHelper::createShape(blob_shape.dim(), blob_shape.dim_size());
 
-    // TODO For now we only support convolutional networks. The input data have already been
-    // transformed from Caffe NCHW format to ModelIR NHWC; reflect the changes in the IR.
-    assert(shape.rank() == 4);
-    shape = Shape{shape.dim(0), shape.dim(2), shape.dim(3), shape.dim(1)};
-
-    // TODO Remove this limitation.
-    assert(shape.dim(0) == 1);
+    // TODO For now we only support convolutional networks with one element per batch.
+    assert(shape.rank() == 4 && shape.dim(0) == 1);
 
-    // FIXME We cannot use CaffeOpCreator::createOp here, because we have to set name instantly.
-    // Otherwise interpreter backend won't work.
+    // TODO Do not transpose data on input and remove transpose.
+    shape = Shape{shape.dim(0), shape.dim(2), shape.dim(3), shape.dim(1)};
     auto variable = createOp<ops::VariableOp>(blob_name, shape);
-    descriptors.push_back(variable->getOutput(0));
+    descriptors.push_back(convertMIRToCaffe(variable->getOutput(0)));
   }
 
   return descriptors;
@@ -281,23 +271,24 @@ CaffeOpCreator::convertConvolution(const caffe::LayerParameter& layer,
     // This is depthwise convolution
     // TODO handle properly kernel with layer multiplier
     std::shared_ptr<IrTensor> transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(params[0]);
-    conv2d = createOp<ops::DepthwiseConv2DOp>(layer.name(), inputs[0], *transposed_tensor, strides,
-                                              paddings);
+    conv2d = createOp<ops::DepthwiseConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]),
+                                              *transposed_tensor, strides, paddings);
   } else {
     if (num_groups != 1) {
       // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
       unfolded_tensor = fixGroupedKernel(opts.group(), params[0]);
     }
-    conv2d = createOp<ops::Conv2DOp>(layer.name(), inputs[0], *unfolded_tensor, strides, paddings);
+    conv2d = createOp<ops::Conv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]), *unfolded_tensor,
+                                     strides, paddings);
   }
 
   // bias_term is optional (so might not be present) and defaults to true
   if (!opts.has_bias_term() || opts.bias_term()) {
     auto bias_add = createOp<ops::BiasAddOp>(layer.name() + ".bias", conv2d->getOutput(0),
                                              *params[1]);
-    return {bias_add->getOutput(0)};
+    return {convertMIRToCaffe(bias_add->getOutput(0))};
   } else {
-    return {conv2d->getOutput(0)};
+    return {convertMIRToCaffe(conv2d->getOutput(0))};
   }
 }
 
@@ -343,21 +334,19 @@ CaffeOpCreator::convertInnerProduct(const LayerParameter& layer,
   }
 }
 
-void CaffeOpCreator::checkConcat(const caffe::ConcatParameter& opts,
-                                 std::set<std::string>& problemsOpSet) {
-  if (opts.axis() != 1)
-    problemsOpSet.insert("Concat: unsupported axis");
-}
-
 std::vector<mir::IODescriptor>
 CaffeOpCreator::convertConcat(const caffe::LayerParameter& layer,
                               const std::vector<mir::IODescriptor>& inputs) {
   auto& opts = layer.concat_param();
-  // NCHW -> NHWC
-  assert(opts.axis() == 1);
-  int32_t axis = 3;
-  auto concat = createOp<ops::ConcatOp>(layer.name(), inputs, axis);
-  return {concat->getOutput(0)};
+  if (cli::debugTranspose) {
+    auto concat = createOp<ops::ConcatOp>(layer.name(), inputs, opts.axis());
+    return {concat->getOutput(0)};
+  } else {
+    assert(opts.axis() == 1);
+    int32_t axis = 3;
+    auto concat = createOp<ops::ConcatOp>(layer.name(), inputs, axis);
+    return {concat->getOutput(0)};
+  }
 }
 
 void CaffeOpCreator::checkPooling(const PoolingParameter& opts,
@@ -401,23 +390,42 @@ CaffeOpCreator::convertPooling(const caffe::LayerParameter& layer,
       assert(false);
   }
 
-  auto pooling = createOp<ops::PoolOp>(layer.name(), inputs[0], window_shape, strides, pool_type,
-                                       paddings, border_type, ops::PoolOp::RoundMode::ceil);
-
-  return {pooling->getOutput(0)};
+  auto pooling = createOp<ops::PoolOp>(layer.name(), convertCaffeToMIR(inputs[0]), window_shape,
+                                       strides, pool_type, paddings, border_type,
+                                       ops::PoolOp::RoundMode::ceil);
+  return {convertMIRToCaffe(pooling->getOutput(0))};
 }
 
 std::vector<mir::IODescriptor>
 CaffeOpCreator::convertSoftmax(const caffe::LayerParameter& layer,
                                const std::vector<mir::IODescriptor>& inputs) {
-  assert(inputs.size() == 1);
   auto& opts = layer.softmax_param();
-  auto input = inputs[0];
-  auto& input_shape = input.op->getOutputShape(input.index);
-  // Workaround until we've got Transpose operation.
-  assert(input_shape.rank() == 4 || input_shape.rank() == 2);
-  auto softmax = createOp<ops::SoftmaxOp>(layer.name(), input, getAxisValue(opts));
-  return {softmax->getOutput(0)};
+
+  if (cli::debugTranspose) {
+    // CPP and ACL backends are able to perform Softmax only along the last axis.
+    if (inputs[0].op->getOutputShape(inputs[0].index).rank() == 4) {
+      // For now, we only account for the most common case.
+      if (opts.axis() != 1)
+        throw PassException("Softmax: unsupported axis");
+      int32_t axis = 3;
+      auto input = createOp<ops::TransposeOp>(layer.name() + ".trans1", inputs[0],
+                                              std::vector<std::size_t>{0, 2, 3, 1});
+      auto softmax = createOp<ops::SoftmaxOp>(layer.name(), input->getOutput(0), axis);
+      auto result = createOp<ops::TransposeOp>(layer.name() + ".trans2", softmax->getOutput(0),
+                                               std::vector<std::size_t>{0, 3, 1, 2});
+      return {result->getOutput(0)};
+    }
+
+    auto softmax = createOp<ops::SoftmaxOp>(layer.name(), inputs[0], opts.axis());
+    return {softmax->getOutput(0)};
+  } else {
+    auto& input = inputs[0];
+    auto& input_shape = input.op->getOutputShape(input.index);
+    if (opts.axis() != 1)
+      throw PassException("Softmax: unsupported axis");
+    auto softmax = createOp<ops::SoftmaxOp>(layer.name(), inputs[0], -1);
+    return {softmax->getOutput(0)};
+  }
 }
 
 void CaffeOpCreator::checkReshape(const ReshapeParameter& opts,
@@ -468,15 +476,15 @@ CaffeOpCreator::convertScale(const caffe::LayerParameter& layer,
                              const std::vector<mir::IODescriptor>& inputs,
                              const std::vector<std::shared_ptr<IrTensor>>& params) {
   auto& opts = layer.scale_param();
-  auto scale = createOp<ops::ScaleOp>(layer.name(), inputs[0], std::move(*params[0]));
+  auto scale = createOp<ops::ScaleOp>(layer.name(), convertCaffeToMIR(inputs[0]), *params[0]);
 
   // bias_term is optional (so might not be present) and defaults to true
   if (!opts.has_bias_term() || opts.bias_term()) {
     auto bias_add = createOp<ops::BiasAddOp>(layer.name() + ".bias", scale->getOutput(0),
                                              *params[1]);
-    return {bias_add->getOutput(0)};
+    return {convertMIRToCaffe(bias_add->getOutput(0))};
   } else {
-    return {scale->getOutput(0)};
+    return {convertMIRToCaffe(scale->getOutput(0))};
   }
 }
 
@@ -506,7 +514,8 @@ CaffeOpCreator::convertBatchNorm(const caffe::LayerParameter& layer,
 
   for (Index idx: ShapeRange(bias_data.getShape()))
     bias_data.at(idx) *= -scale_factor;
-  auto bias_add = createOp<ops::BiasAddOp>(layer.name() + ".bias", inputs[0], *params[0]);
+  auto bias_add = createOp<ops::BiasAddOp>(layer.name() + ".bias", convertCaffeToMIR(inputs[0]),
+                                           *params[0]);
 
   // create scale argument from variance:
   // multiply elements of variance by scaleFactor and
@@ -516,7 +525,7 @@ CaffeOpCreator::convertBatchNorm(const caffe::LayerParameter& layer,
     scale_data.at(idx) = 1.0f / std::sqrt(scale_data.at(idx) * scale_factor + eps);
   auto scale = createOp<ops::ScaleOp>(layer.name() + ".scale", bias_add->getOutput(0), *params[1]);
 
-  return {scale->getOutput(0)};
+  return {convertMIRToCaffe(scale->getOutput(0))};
 }
 
 std::vector<mir::IODescriptor>
@@ -546,16 +555,16 @@ CaffeOpCreator::convertDeconvolution(const caffe::LayerParameter& layer,
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     unfolded_tensor = fixGroupedKernel(opts.group(), params[0]);
   }
-  auto deconv2d = createOp<ops::DeConv2DOp>(layer.name(), inputs[0], *unfolded_tensor, strides,
-                                            paddings);
+  auto deconv2d = createOp<ops::DeConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]),
+                                            *unfolded_tensor, strides, paddings);
 
   // bias_term is optional (so might not be present) and defaults to true
   if (!opts.has_bias_term() || opts.bias_term()) {
     auto bias_add = createOp<ops::BiasAddOp>(layer.name() + ".bias", deconv2d->getOutput(0),
                                              *params[1]);
-    return {bias_add->getOutput(0)};
+    return {convertMIRToCaffe(bias_add->getOutput(0))};
   } else {
-    return {deconv2d->getOutput(0)};
+    return {convertMIRToCaffe(deconv2d->getOutput(0))};
   }
 }
 
index bd102f2..55fac40 100644 (file)
@@ -106,8 +106,6 @@ public:
   convertSplit(const caffe::LayerParameter& layer,
                const std::vector<mir::IODescriptor>& inputs);
 
-  void checkConcat(const caffe::ConcatParameter& opts, std::set<std::string>&);
-
   void checkConvolution(const caffe::ConvolutionParameter& layer, std::set<std::string>&);
 
   void checkInnerProduct(const caffe::InnerProductParameter& opts, std::set<std::string>&);
@@ -124,6 +122,10 @@ public:
 private:
   mir::Graph* _graph = nullptr;
 
+  mir::IODescriptor convertCaffeToMIR(const mir::IODescriptor& arg);
+
+  mir::IODescriptor convertMIRToCaffe(const mir::IODescriptor& arg);
+
   template<typename OpType, typename... Types>
   mir::Operation* createOp(const std::string& name, Types&&... args);
 };