[nnc] Caffe2 importer improvements (#2664)
authorIvan Vagin/AI Tools Lab /SRR/Engineer/삼성전자 <ivan.vagin@samsung.com>
Thu, 13 Dec 2018 16:53:10 +0000 (19:53 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 13 Dec 2018 16:53:10 +0000 (19:53 +0300)
* Tested and fixed operators needed for 'inception': Add, Mul, Concat, SpatialBN
* Supported custom paddings
* Supported custom pooling window shapes
* Several bugs fixed

This PR enables 'TransposeOp' in caffe2 frontend, which is not currently supported by ACL backend.

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.h

index d9673ae..d5d8116 100644 (file)
@@ -109,21 +109,27 @@ void Caffe2Importer::collectUnsupportedOp(const OperatorDef& op) {
 
   SupportedCaffe2OpType opType = _operatorTypes.at(op.type());
   switch (opType) {
+    case SupportedCaffe2OpType::add:
+      _opCreator->checkAdd(op, _problemsOpSet);
+      break;
     case SupportedCaffe2OpType::FC:
       _opCreator->checkFC(op, _problemsOpSet);
       break;
+    case SupportedCaffe2OpType::mul:
+      _opCreator->checkMul(op, _problemsOpSet);
+      break;
     case SupportedCaffe2OpType::spatialBN:
       _opCreator->checkSpatialBN(op, _problemsOpSet);
       break;
-    case SupportedCaffe2OpType::add:
     case SupportedCaffe2OpType::averagePool:
+    case SupportedCaffe2OpType::conv:
+    case SupportedCaffe2OpType::maxPool:
+      _opCreator->checkConvLikeOp(op, _problemsOpSet);
+      break;
     case SupportedCaffe2OpType::concat:
     case SupportedCaffe2OpType::constantFill:
-    case SupportedCaffe2OpType::conv:
     case SupportedCaffe2OpType::dropout:
     case SupportedCaffe2OpType::givenTensorFill:
-    case SupportedCaffe2OpType::maxPool:
-    case SupportedCaffe2OpType::mul:
     case SupportedCaffe2OpType::relu:
     case SupportedCaffe2OpType::softmax:
     case SupportedCaffe2OpType::sum:
index 57b4652..086a66e 100644 (file)
@@ -59,34 +59,100 @@ using nnc::mir::transposeTensor;
 // Helper functions
 //
 
-mir::IODescriptor Caffe2OpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
-  if (cli::debugTranspose) {
-    // NCHW -> NHWC
-    auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
-    return transpose->getOutput(0);
+static std::pair<std::vector<int32_t>, std::vector<int32_t>>
+getPadding(const ::caffe2::OperatorDef& op) {
+  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
+                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
+  if (has_custom_pad) {
+    int32_t pad_l = getSingleArgument(op, "pad_l", 0);
+    int32_t pad_t = getSingleArgument(op, "pad_t", 0);
+    int32_t pad_r = getSingleArgument(op, "pad_r", 0);
+    int32_t pad_b = getSingleArgument(op, "pad_b", 0);
+
+    std::vector<int32_t> padding_before{pad_t, pad_l};
+    std::vector<int32_t> padding_after{pad_b, pad_r};
+    return {padding_before, padding_after};
+  }
+
+  int32_t pad = getSingleArgument(op, "pad", 0);
+  return {{pad, pad}, {pad, pad}};
+};
+
+static Shape getWindowShape(const ::caffe2::OperatorDef& op,
+                            const std::vector<IODescriptor>& inputs) {
+  int is_global_pooling = getSingleArgument(op, "global_pooling", 0);
+  bool has_custom_kernel_size = hasArgument(op.arg(), "kernel_h")
+                                || hasArgument(op.arg(), "kernel_w");
+
+  int kernel_h, kernel_w;
+  if (is_global_pooling) {
+    auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+    assert(input_shape.rank() == 4 && "getWindowShape() inputs must be of rank 4");
+    kernel_h = input_shape.dim(2);
+    kernel_w = input_shape.dim(3);
   } else {
-    return arg;
+    if (has_custom_kernel_size) {
+      kernel_h = getSingleArgument(op, "kernel_h", 0);
+      kernel_w = getSingleArgument(op, "kernel_w", 0);
+    } else {
+      kernel_h = kernel_w = getSingleArgument(op, "kernel", 0);
+    }
   }
+  return Shape({kernel_h, kernel_w});
+}
+
+mir::IODescriptor Caffe2OpCreator::convertCaffeToMIR(const mir::IODescriptor& arg) {
+  // NCHW -> NHWC
+  auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
+  return transpose->getOutput(0);
 }
 
 mir::IODescriptor Caffe2OpCreator::convertMIRToCaffe(const mir::IODescriptor& arg) {
-  if (cli::debugTranspose) {
-    // NHWC -> NCHW
-    auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
-    return transpose->getOutput(0);
-  } else {
-    return arg;
-  }
+  // NHWC -> NCHW
+  auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
+  return transpose->getOutput(0);
 }
 
 //
 // Check functions
 //
 
-void Caffe2OpCreator::commonCheck(const ::caffe2::OperatorDef& op,
-                                  std::set<std::string>& problemsOpSet) {
-  if (getSingleArgument(op, "order", "NCHW") != "NCHW")
-    problemsOpSet.insert("Only 'NCHW' oreder is supported");
+void Caffe2OpCreator::checkAdd(const ::caffe2::OperatorDef& op,
+                               std::set<std::string>& problemsOpSet) {
+  commonCheck(op, problemsOpSet);
+
+  if (getSingleArgument(op, "axis", 1) != 1)
+    problemsOpSet.insert("Add: only 'axis' = 1 is supported");
+
+  if (getSingleArgument(op, "broadcast", 1) != 1)
+    problemsOpSet.insert("Add: only enabled 'broadcast' is supported");
+}
+
+void Caffe2OpCreator::checkConvLikeOp(const ::caffe2::OperatorDef& op,
+                                      std::set<std::string>& problemsOpSet) {
+  commonCheck(op, problemsOpSet);
+
+  // Padding
+  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
+                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
+
+  if (has_custom_pad && hasArgument(op.arg(), "pad"))
+    problemsOpSet.insert("Custom pad can't be combined with overall pad");
+
+  if (has_custom_pad && !(hasArgument(op.arg(), "pad_l") && hasArgument(op.arg(), "pad_r")
+                          && hasArgument(op.arg(), "pad_t") && hasArgument(op.arg(), "pad_b")))
+    problemsOpSet.insert("If one custom pad specified - all custom pads must be specified");
+
+  // Kernel size
+  bool has_custom_kernel_size = hasArgument(op.arg(), "kernel_h")
+                                || hasArgument(op.arg(), "kernel_w");
+
+  if (has_custom_kernel_size && hasArgument(op.arg(), "kernel"))
+    problemsOpSet.insert("Custom kernel size can't be combined with overall kernel size");
+
+  if (has_custom_kernel_size && !(hasArgument(op.arg(), "kernel_h")
+                                  && hasArgument(op.arg(), "kernel_w")))
+    problemsOpSet.insert("If one custom kernel size specified - all custom kernel sizes must be specified");
 }
 
 void Caffe2OpCreator::checkFC(const ::caffe2::OperatorDef& op,
@@ -97,12 +163,32 @@ void Caffe2OpCreator::checkFC(const ::caffe2::OperatorDef& op,
       problemsOpSet.insert(std::string("FC: only default '") + s + "' value is supported");
 }
 
+void Caffe2OpCreator::checkMul(const ::caffe2::OperatorDef& op,
+                               std::set<std::string>& problemsOpSet) {
+  commonCheck(op, problemsOpSet);
+
+  if (getSingleArgument(op, "axis", 1) != 1)
+    problemsOpSet.insert("Mul: only 'axis' = 1 is supported");
+
+  if (getSingleArgument(op, "broadcast", 1) != 1)
+    problemsOpSet.insert("Mul: only enabled 'broadcast' is supported");
+}
+
 void Caffe2OpCreator::checkSpatialBN(const ::caffe2::OperatorDef& op,
                                      std::set<std::string>& problemsOpSet) {
   commonCheck(op, problemsOpSet);
   if (op.input_size() != 5)
     problemsOpSet.insert(
             "SpatialBN must have exactly 5 inputs ('sums' and 'sumsq' are not supported yet)");
+
+  if (getSingleArgument(op, "is_test", 1) != 1)
+    problemsOpSet.insert(std::string("SpatialBN: only test mode supported"));
+}
+
+void Caffe2OpCreator::commonCheck(const ::caffe2::OperatorDef& op,
+                                  std::set<std::string>& problemsOpSet) {
+  if (getSingleArgument(op, "order", "NCHW") != "NCHW")
+    problemsOpSet.insert("Only 'NCHW' oreder is supported");
 }
 
 //
@@ -112,51 +198,50 @@ void Caffe2OpCreator::checkSpatialBN(const ::caffe2::OperatorDef& op,
 std::vector<mir::IODescriptor>
 Caffe2OpCreator::convertAdd(const std::vector<mir::IODescriptor>& inputs,
                             const ::caffe2::OperatorDef& op,
-                            const MIRTensors& mirTensors) {
-  // TODO: not tested
-  throw PassException("Caffe2 Add op not tested yet");
-  auto& addend = mirTensors.at(op.input(1));
-  auto add = createOp<ops::BiasAddOp>(inputs[0], *addend);
-  return {add->getOutput(0)};
+                            const MIRTensors& mir_tensors) {
+  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  auto& addend = mir_tensors.at(op.input(1));
+
+  assert(addend->getShape().rank() == 1 && "Only 1-rank addend is supported");
+  assert(addend->getShape().numElements() == input_shape.dim(1)
+         && "Only addend size equal to number of input channels is supported");
+
+  // TODO: replace with elementwise op, when broadcating will be added in elementwise op
+  auto add = createOp<ops::BiasAddOp>(convertCaffeToMIR(inputs[0]), *addend);
+  return {convertMIRToCaffe(add->getOutput(0))};
 }
 
 std::vector<IODescriptor>
 Caffe2OpCreator::convertAveragePool(const std::vector<IODescriptor>& inputs,
                                     const OperatorDef& op) {
-  // TODO: implement custom paddings
-  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
-                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
-  if (has_custom_pad)
-    throw PassException("Custom one-side padding not supported yet");
-
-  int kernel_size = static_cast<int>(findArgumentByName(op.arg(), "kernel").i());
-  Shape window_shape = Shape({kernel_size, kernel_size});
+  Shape window_shape = getWindowShape(op, inputs);
 
-  int stride = static_cast<int>(findArgumentByName(op.arg(), "stride").i());
+  int stride = getSingleArgument(op, "stride", 0);
   Shape strides = Shape({stride, stride});
 
   ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::AVG;
-  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::ZEROFILLED;
+  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::EMPTY;
 
-  int pad = getSingleArgument(op, "pad", 0);
-  std::vector<int32_t> padding{pad, pad};
+  std::vector<int32_t> pad_before, pad_after;
+  std::tie(pad_before, pad_after) = getPadding(op);
 
-  auto pooling = createOp<ops::PoolOp>(inputs[0], pool_type, window_shape, strides, padding,
-                                       padding, border_type, ops::PoolOp::RoundMode::ceil);
+  auto pooling = createOp<ops::PoolOp>(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+                                       strides, pad_before, pad_after, border_type,
+                                       ops::PoolOp::RoundMode::floor);
 
-  return {pooling->getOutput(0)};
+  return {convertMIRToCaffe(pooling->getOutput(0))};
 }
 
 std::vector<IODescriptor> Caffe2OpCreator::convertConv(const std::vector<IODescriptor>& inputs,
                                                        const ::caffe2::OperatorDef& op,
-                                                       const MIRTensors& mirTensors) {
+                                                       const MIRTensors& mir_tensors) {
   int stride = getSingleArgument(op, "stride", 1);
   Shape stride_shape = Shape({stride, stride});
 
-  int pad = getSingleArgument(op, "pad", 0);
-  std::vector<int32_t> padding{pad, pad};
+  std::vector<int32_t> pad_before, pad_after;
+  std::tie(pad_before, pad_after) = getPadding(op);
 
-  auto kernel_tensor = transposeTensor<2, 3, 1, 0>(mirTensors.at(op.input(1)));
+  auto kernel_tensor = transposeTensor<2, 3, 1, 0>(mir_tensors.at(op.input(1)));
   auto in_group_size = kernel_tensor->getShape().dim(2);
   auto out_channels = kernel_tensor->getShape().dim(3);
   int num_groups = getSingleArgument(op, "group", 1);
@@ -164,22 +249,21 @@ std::vector<IODescriptor> Caffe2OpCreator::convertConv(const std::vector<IODescr
 
   mir::Operation* conv2d;
   if (is_depthwise) {
-    // This is depthwise convolution
     // TODO handle properly kernel with layer multiplier
     std::shared_ptr<IrTensor> transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
     conv2d = createOp<ops::DepthwiseConv2DOp>(convertCaffeToMIR(inputs[0]), *transposed_tensor,
-                                              stride_shape, padding, padding);
+                                              stride_shape, pad_before, pad_after);
   } else {
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     if (num_groups != 1)
       kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
 
     conv2d = createOp<ops::Conv2DOp>(convertCaffeToMIR(inputs[0]), *kernel_tensor,
-                                     stride_shape, padding, padding);
+                                     stride_shape, pad_before, pad_after);
   }
 
   if (op.input_size() > 2) {  // Bias is optional
-    auto bias_add = createOp<ops::BiasAddOp>(conv2d->getOutput(0), *mirTensors.at(op.input(2)));
+    auto bias_add = createOp<ops::BiasAddOp>(conv2d->getOutput(0), *mir_tensors.at(op.input(2)));
     return {convertMIRToCaffe(bias_add->getOutput(0))};
   }
   return {convertMIRToCaffe(conv2d->getOutput(0))};
@@ -187,80 +271,55 @@ std::vector<IODescriptor> Caffe2OpCreator::convertConv(const std::vector<IODescr
 
 std::vector<IODescriptor> Caffe2OpCreator::convertConcat(const std::vector<IODescriptor>& inputs,
                                                          const ::caffe2::OperatorDef& op) {
-  // TODO: not tested
-  throw PassException("Caffe2 Concat op not tested yet");
-  int axis = getSingleArgument(op, "axis", -1);
+  int axis = getSingleArgument(op, "axis", 1);
   auto result = createOp<ops::ConcatOp>(inputs, axis);
   return {result->getOutput(0)};
 }
 
 std::vector<IODescriptor> Caffe2OpCreator::convertDropout(const std::vector<IODescriptor>& inputs,
                                                           const ::caffe2::OperatorDef& op) {
-  // TODO: not tested
-  throw PassException("Caffe2 Dropout op not tested yet");
   int is_test = getSingleArgument(op, "is_test", 0);
   if (is_test)
     return {inputs[0]};
 
-  float dropot_ratio = getSingleArgument(op, "ratio", 0.5f);
-  auto dropout = createOp<ops::DropoutOp>(inputs[0], dropot_ratio);
+  float dropout_ratio = getSingleArgument(op, "ratio", 0.5f);
+  auto dropout = createOp<ops::DropoutOp>(inputs[0], dropout_ratio);
   return {dropout->getOutput(0)};
 }
 
-// TODO: describe caffe2 FC interface
 std::vector<IODescriptor>
 Caffe2OpCreator::convertFullyConnected(const std::vector<IODescriptor>& inputs,
                                        const ::caffe2::OperatorDef& op,
-                                       const MIRTensors& mirTensors) {
-  auto weightsTensor = mirTensors.at(op.input(1));
-  weightsTensor = transposeTensor<1, 0>(weightsTensor);
-  int32_t fc_input_size = weightsTensor->getShape().dim(0);
-
-  // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize]
-  // It is needed because Caffe2 FC layer takes NCHW input and flattens the CHW part.
-  auto reshape = createOp<ops::ReshapeOp>(inputs[0], Shape({1, fc_input_size}));
+                                       const MIRTensors& mir_tensors) {
+  auto weights_tensor = mir_tensors.at(op.input(1));
+  weights_tensor = transposeTensor<1, 0>(weights_tensor);
 
-  auto fully_connected = createOp<ops::FullyConnectedOp>(reshape->getOutput(0), *weightsTensor);
+  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  // Transform input into 2-D tensor by flattening axes
+  Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], shape);
+  auto fully_connected = createOp<ops::FullyConnectedOp>(reshape->getOutput(0), *weights_tensor);
 
-  auto bias = createOp<ops::BiasAddOp>(fully_connected->getOutput(0), *mirTensors.at(op.input(2)));
+  auto bias = createOp<ops::BiasAddOp>(fully_connected->getOutput(0), *mir_tensors.at(op.input(2)));
   return {bias->getOutput(0)};
 }
 
-std::vector<IODescriptor>
-Caffe2OpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) {
-  // TODO For now we only support convolutional networks with one element per batch.
-  assert(input_shape.rank() == 4 && input_shape.dim(0) == 1);
-
-  // TODO Do not transpose data on input and remove transpose.
-  auto transposed_shape = mir::Shape{input_shape.dim(0), input_shape.dim(2),
-                                     input_shape.dim(3), input_shape.dim(1)};
-  auto variable = _graph->create<ops::VariableOp>(input_name, transposed_shape);
-  return {convertMIRToCaffe(variable->getOutput(0))};
-}
-
 std::vector<IODescriptor> Caffe2OpCreator::convertMaxPool(const std::vector<IODescriptor>& inputs,
                                                           const OperatorDef& op) {
-  // TODO: implement custom paddings
-  bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r")
-                        || hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
-  if (has_custom_pad)
-    throw PassException("Custom one-side padding not supported yet");
-
-  int window_length = static_cast<int>(findArgumentByName(op.arg(), "kernel").i());
-  Shape window_shape = Shape({window_length, window_length});
+  Shape window_shape = getWindowShape(op, inputs);
 
-  int stride = static_cast<int>(findArgumentByName(op.arg(), "stride").i());
+  int stride = getSingleArgument(op, "stride", 0);
   Shape strides = Shape({stride, stride});
 
   ops::PoolOp::PoolingType pool_type = ops::PoolOp::PoolingType::MAX;
-  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::EMPTY;
+  ops::PoolOp::BorderType border_type = ops::PoolOp::BorderType::ZEROFILLED;
 
-  int pad = getSingleArgument(op, "pad", 0);
-  std::vector<int32_t> padding{pad, pad};
+  std::vector<int32_t> pad_before, pad_after;
+  std::tie(pad_before, pad_after) = getPadding(op);
 
   auto pooling = createOp<ops::PoolOp>(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
-                                       strides, padding, padding, border_type,
-                                       ops::PoolOp::RoundMode::ceil);
+                                       strides, pad_before, pad_after, border_type,
+                                       ops::PoolOp::RoundMode::floor);
 
   return {convertMIRToCaffe(pooling->getOutput(0))};
 }
@@ -268,12 +327,17 @@ std::vector<IODescriptor> Caffe2OpCreator::convertMaxPool(const std::vector<IODe
 std::vector<mir::IODescriptor>
 Caffe2OpCreator::convertMul(const std::vector<mir::IODescriptor>& inputs,
                             const ::caffe2::OperatorDef& op,
-                            const MIRTensors& mirTensors) {
-  // TODO: not tested
-  throw PassException("Caffe Mul op not tested yet");
-  auto& multiplier = mirTensors.at(op.input(1));
-  auto mul = createOp<ops::ScaleOp>(inputs[0], *multiplier);
-  return {mul->getOutput(0)};
+                            const MIRTensors& mir_tensors) {
+  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  auto& multiplier = mir_tensors.at(op.input(1));
+
+  assert(multiplier->getShape().rank() == 1 && "Only 1-rank multiplier is supported");
+  assert(multiplier->getShape().numElements() == input_shape.dim(1)
+         && "Only multiplier size equal to number of input channels is supported");
+
+  // TODO: replace with elementwise op, when broadcating will be added in elementwise op
+  auto mul = createOp<ops::ScaleOp>(convertCaffeToMIR(inputs[0]), *multiplier);
+  return {convertMIRToCaffe(mul->getOutput(0))};
 }
 
 std::vector<IODescriptor> Caffe2OpCreator::convertRelu(const std::vector<IODescriptor>& inputs) {
@@ -291,27 +355,25 @@ std::vector<IODescriptor> Caffe2OpCreator::convertSoftmax(const std::vector<IODe
 std::vector<mir::IODescriptor>
 Caffe2OpCreator::convertSpatialBN(const std::vector<mir::IODescriptor>& inputs,
                                   const ::caffe2::OperatorDef& op,
-                                  const MIRTensors& mirTensors) {
-  // TODO: not tested
-  throw PassException("Caffe2 SpatialBN op not tested yet");
+                                  const MIRTensors& mir_tensors) {
   // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
 
-  auto& scale = mirTensors.at(op.input(1));
-  auto& bias = mirTensors.at(op.input(2));
-  auto& mean = mirTensors.at(op.input(3));
-  auto& var = mirTensors.at(op.input(4));
+  auto& scale = mir_tensors.at(op.input(1));
+  auto& bias = mir_tensors.at(op.input(2));
+  auto& mean = mir_tensors.at(op.input(3));
+  auto& var = mir_tensors.at(op.input(4));
   float eps = getSingleArgument(op, "epsilon", 1e-5f);
 
   // res1 = X - mean
   Tensor<float> bias_data(*mean);
-  for (Index idx: ShapeRange(bias_data.getShape()))
+  for (auto& idx: ShapeRange(bias_data.getShape()))
     bias_data.at(idx) *= -1;
   auto bias_add_1 = createOp<ops::BiasAddOp>(convertCaffeToMIR(inputs[0]), *mean);
 
   // res2 = res1 * scale / (var + epsilon)
   Tensor<float> multiplier(*scale);
-  for (Index idx: ShapeRange(scale->getShape()))
-    multiplier.at(idx) = 1.0f / std::sqrt(*(float*) var->at(idx) + eps);
+  for (auto& idx: ShapeRange(scale->getShape()))
+    multiplier.at(idx) /= std::sqrt(*(float*) var->at(idx) + eps);
   auto scale_op = createOp<ops::ScaleOp>(bias_add_1->getOutput(0), *scale);
 
   // overall_res = res2 + bias
@@ -321,8 +383,24 @@ Caffe2OpCreator::convertSpatialBN(const std::vector<mir::IODescriptor>& inputs,
 }
 
 std::vector<IODescriptor> Caffe2OpCreator::convertSum(const std::vector<IODescriptor>& inputs) {
+  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  for (auto& in : inputs)
+    assert(input_shape == in.op->getOutputShape(inputs[0].index) && "All Sum inputs must have same shape");
+
   auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::add);
   return {op->getOutput(0)};
 }
 
+std::vector<IODescriptor>
+Caffe2OpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) {
+  // TODO For now we only support convolutional networks with one element per batch.
+  assert(input_shape.rank() == 4 && input_shape.dim(0) == 1);
+
+  // TODO Do not transpose data on input and remove transpose.
+  auto transposed_shape = mir::Shape{input_shape.dim(0), input_shape.dim(2),
+                                     input_shape.dim(3), input_shape.dim(1)};
+  auto variable = _graph->create<ops::VariableOp>(input_name, transposed_shape);
+  return {convertMIRToCaffe(variable->getOutput(0))};
+}
+
 } // namespace nnc
index 09b2d0e..fdff757 100644 (file)
@@ -42,12 +42,18 @@ class Caffe2OpCreator {
 public:
   explicit Caffe2OpCreator(Graph* g) : _graph(g) {};
 
-  void commonCheck(const ::caffe2::OperatorDef&, std::set<std::string>&);
+  void checkAdd(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
+  void checkConvLikeOp(const ::caffe2::OperatorDef&, std::set<std::string>&);
 
   void checkFC(const ::caffe2::OperatorDef&, std::set<std::string>&);
 
+  void checkMul(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
   void checkSpatialBN(const ::caffe2::OperatorDef&, std::set<std::string>&);
 
+  void commonCheck(const ::caffe2::OperatorDef&, std::set<std::string>&);
+
   std::vector<mir::IODescriptor> convertAdd(const std::vector<mir::IODescriptor>&,
                                             const ::caffe2::OperatorDef&, const MIRTensors&);