[mir_caffe2] Do not set names of operations (#6812)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Wed, 28 Aug 2019 07:53:09 +0000 (16:53 +0900)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 28 Aug 2019 07:53:09 +0000 (10:53 +0300)
Remove useless setting of operation names.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-caffe2-importer/caffe2_importer.cpp
compiler/mir-caffe2-importer/caffe2_importer.h
compiler/mir-caffe2-importer/caffe2_op_creator.cpp
compiler/mir-caffe2-importer/caffe2_op_creator.h

index 7432589..b8b6f41 100644 (file)
@@ -19,6 +19,7 @@
 #include "caffe2_op_creator.h"
 #include "caffe2_proto_helper.h"
 
+#include "mir/ops/InputOp.h"
 #include "mir/ops/OutputOp.h"
 
 #include <google/protobuf/io/zero_copy_stream_impl.h>
@@ -96,10 +97,11 @@ std::unique_ptr<mir::Graph> Caffe2Importer::createIR()
 
   // Create inputs. This has to be done after processing initializers, because they may contain
   // fake inputs.
-  // TODO Caffe2 does not provide a way to detect model inputs. For now assume that the first input
-  // of the first operation is the only input to the model.
+  // TODO Caffe2 does not provide a way to detect model inputs and outputs. For now assume that:
+  //      - there is exactly one input;
+  //      - the input is for the first layer.
   const auto &input_name = _predict_net->op(0).input(0);
-  auto input = _opCreator->createInput(input_name, _inputShapes[0]);
+  auto input = _graph->create<mir::ops::InputOp>(_inputShapes[0])->getOutput(0);
   setOutputForTensor(input_name, input);
 
   for (const auto &op : _predict_net->op())
@@ -204,12 +206,6 @@ void Caffe2Importer::createMIRNodesFromOp(const OperatorDef &op)
   {
     setOutputForTensor(op.output(i), outputs[i]);
   }
-
-  // `outputs` can be empty if constant input was not processed.
-  if (!outputs.empty())
-  {
-    _lastMIROp = outputs.at(0)->getNode();
-  }
 }
 
 std::vector<mir::Operation::Output *> Caffe2Importer::getInputMIROps(const OperatorDef &op)
@@ -244,10 +240,13 @@ mir::Operation::Output *Caffe2Importer::getOutputForTensor(const std::string &na
 
 void Caffe2Importer::setGraphOutputs()
 {
-  // For now, we assume that:
-  //   - there is exactly one output;
-  //   - the output is from the last layer.
-  _graph->create<mir::ops::OutputOp>("out", _lastMIROp->getOutput(0));
+  // Create outputs.
+  // TODO Caffe2 does not provide a way to detect model inputs and outputs. For now assume that:
+  //      - there is exactly one output;
+  //      - the output is from the last layer.
+  const auto &output_name = _predict_net->op().rbegin()->output(0);
+  auto output = getOutputForTensor(output_name);
+  _graph->create<mir::ops::OutputOp>(output);
 }
 
 const std::map<std::string, SupportedCaffe2OpType> Caffe2Importer::_operatorTypes = {
index ccff1c9..9af8387 100644 (file)
@@ -52,7 +52,6 @@ private:
 
   // Maps Caffe2 operator input names to corresponding MIR operation outputs.
   std::unordered_map<std::string, mir::Operation::Output *> _blobNameToOutput;
-  mir::Operation *_lastMIROp = nullptr;
 
   void import();
   std::unique_ptr<mir::Graph> createIR();
index 4242ab7..430fbad 100644 (file)
@@ -24,7 +24,6 @@
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
 #include "mir/ops/FullyConnectedOp.h"
-#include "mir/ops/InputOp.h"
 #include "mir/ops/MulOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
@@ -217,16 +216,14 @@ static Shape getWindowShape(const ::caffe2::OperatorDef &op,
 mir::Operation::Output *Caffe2OpCreator::convertCaffeToMIR(mir::Operation::Output *arg)
 {
   // NCHW -> NHWC
-  auto transpose =
-      createOp<ops::TransposeOp>("CaffeToMIR", arg, std::vector<std::size_t>{0, 2, 3, 1});
+  auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
   return transpose->getOutput(0);
 }
 
 mir::Operation::Output *Caffe2OpCreator::convertMIRToCaffe(mir::Operation::Output *arg)
 {
   // NHWC -> NCHW
-  auto transpose =
-      createOp<ops::TransposeOp>("MIRToCaffe", arg, std::vector<std::size_t>{0, 3, 1, 2});
+  auto transpose = createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
   return transpose->getOutput(0);
 }
 
@@ -320,7 +317,7 @@ Caffe2OpCreator::convertConstant(const std::vector<mir::Operation::Output *> &in
   if (!hasArgument(op.arg(), "values"))
     return {};
 
-  return {createOp<ops::ConstantOp>("", createTensor(op))->getOutput(0)};
+  return {createOp<ops::ConstantOp>(createTensor(op))->getOutput(0)};
 }
 
 std::vector<mir::Operation::Output *>
@@ -330,12 +327,12 @@ Caffe2OpCreator::convertAdd(const std::vector<mir::Operation::Output *> &inputs,
   if (getSingleArgument(op, "broadcast", 0) != 0)
   {
     // FIXME This only works when 'axis' == 1 and the second input is 1-D.
-    auto result = createOp<ops::AddOp>("", convertCaffeToMIR(inputs[0]), inputs[1])->getOutput(0);
+    auto result = createOp<ops::AddOp>(convertCaffeToMIR(inputs[0]), inputs[1])->getOutput(0);
 
     return {convertMIRToCaffe(result)};
   }
 
-  auto result = createOp<ops::AddOp>("", inputs[0], inputs[1])->getOutput(0);
+  auto result = createOp<ops::AddOp>(inputs[0], inputs[1])->getOutput(0);
   return {result};
 }
 std::vector<mir::Operation::Output *>
@@ -354,8 +351,8 @@ Caffe2OpCreator::convertAveragePool(const std::vector<mir::Operation::Output *>
   std::vector<int32_t> pad_before, pad_after;
   std::tie(pad_before, pad_after) = getPadding(op);
 
-  auto pooling = createOp<ops::PoolOp>("Average_Pool", convertCaffeToMIR(inputs[0]), pool_type,
-                                       window_shape, strides, pad_before, pad_after, border_type);
+  auto pooling = createOp<ops::PoolOp>(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+                                       strides, pad_before, pad_after, border_type);
 
   return {convertMIRToCaffe(pooling->getOutput(0))};
 }
@@ -386,9 +383,9 @@ Caffe2OpCreator::convertConv(const std::vector<mir::Operation::Output *> &inputs
   {
     // TODO handle properly kernel with layer multiplier
     auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
-    auto kernel = createOp<ops::ConstantOp>("Constant", transposed_tensor)->getOutput(0);
-    result = createOp<ops::DepthwiseConv2DOp>("Depthwise_Conv2D", convertCaffeToMIR(inputs[0]),
-                                              kernel, stride_shape, pad_before, pad_after)
+    auto kernel = createOp<ops::ConstantOp>(transposed_tensor)->getOutput(0);
+    result = createOp<ops::DepthwiseConv2DOp>(convertCaffeToMIR(inputs[0]), kernel, stride_shape,
+                                              pad_before, pad_after)
                  ->getOutput(0);
   }
   else
@@ -397,15 +394,15 @@ Caffe2OpCreator::convertConv(const std::vector<mir::Operation::Output *> &inputs
     if (num_groups != 1)
       kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
     kernel_tensor = transposeTensor<3, 0, 1, 2>(kernel_tensor);
-    auto kernel = createOp<ops::ConstantOp>("Constant", kernel_tensor)->getOutput(0);
-    result = createOp<ops::Conv2DOp>("Conv2D", convertCaffeToMIR(inputs[0]), kernel, stride_shape,
-                                     pad_before, pad_after)
+    auto kernel = createOp<ops::ConstantOp>(kernel_tensor)->getOutput(0);
+    result = createOp<ops::Conv2DOp>(convertCaffeToMIR(inputs[0]), kernel, stride_shape, pad_before,
+                                     pad_after)
                  ->getOutput(0);
   }
 
   if (op.input_size() > 2)
   {
-    result = createOp<ops::AddOp>("", result, inputs[2])->getOutput(0);
+    result = createOp<ops::AddOp>(result, inputs[2])->getOutput(0);
   }
 
   return {convertMIRToCaffe(result)};
@@ -419,7 +416,7 @@ Caffe2OpCreator::convertConcat(const std::vector<mir::Operation::Output *> &inpu
 
   // `1` corresponds to the default (channels) axis.
   int axis = getSingleArgument(op, "axis", 1);
-  auto result = createOp<ops::ConcatOp>("Concat", inputs, axis);
+  auto result = createOp<ops::ConcatOp>(inputs, axis);
   return {result->getOutput(0)};
 }
 
@@ -449,10 +446,10 @@ Caffe2OpCreator::convertFC(const std::vector<mir::Operation::Output *> &inputs,
   // Transform input into 2-D tensor by flattening axes
   Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
 
-  auto reshape = createOp<ops::ReshapeOp>("Reshape", inputs[0], shape)->getOutput(0);
-  auto weights = createOp<ops::ConstantOp>("Constant", weights_tensor)->getOutput(0);
-  auto result = createOp<ops::FullyConnectedOp>("Fully_Connected", reshape, weights)->getOutput(0);
-  result = createOp<ops::AddOp>("", result, inputs[2])->getOutput(0);
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], shape)->getOutput(0);
+  auto weights = createOp<ops::ConstantOp>(weights_tensor)->getOutput(0);
+  auto result = createOp<ops::FullyConnectedOp>(reshape, weights)->getOutput(0);
+  result = createOp<ops::AddOp>(result, inputs[2])->getOutput(0);
 
   return {result};
 }
@@ -472,8 +469,8 @@ Caffe2OpCreator::convertMaxPool(const std::vector<mir::Operation::Output *> &inp
   std::vector<int32_t> pad_before, pad_after;
   std::tie(pad_before, pad_after) = getPadding(op);
 
-  auto pooling = createOp<ops::PoolOp>("Pool", convertCaffeToMIR(inputs[0]), pool_type,
-                                       window_shape, strides, pad_before, pad_after, border_type);
+  auto pooling = createOp<ops::PoolOp>(convertCaffeToMIR(inputs[0]), pool_type, window_shape,
+                                       strides, pad_before, pad_after, border_type);
 
   return {convertMIRToCaffe(pooling->getOutput(0))};
 }
@@ -485,19 +482,19 @@ Caffe2OpCreator::convertMul(const std::vector<mir::Operation::Output *> &inputs,
   if (getSingleArgument(op, "broadcast", 0) != 0)
   {
     // FIXME This only works when `axis` == 1 and the second input is 1-D.
-    auto result = createOp<ops::MulOp>("", convertCaffeToMIR(inputs[0]), inputs[1])->getOutput(0);
+    auto result = createOp<ops::MulOp>(convertCaffeToMIR(inputs[0]), inputs[1])->getOutput(0);
 
     return {convertMIRToCaffe(result)};
   }
 
-  auto result = createOp<ops::MulOp>("", inputs[0], inputs[1])->getOutput(0);
+  auto result = createOp<ops::MulOp>(inputs[0], inputs[1])->getOutput(0);
   return {result};
 }
 
 std::vector<mir::Operation::Output *>
 Caffe2OpCreator::convertRelu(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto relu = createOp<ops::ReluOp>("Relu", inputs[0]);
+  auto relu = createOp<ops::ReluOp>(inputs[0]);
   return {relu->getOutput(0)};
 }
 
@@ -513,7 +510,7 @@ Caffe2OpCreator::convertResizeNearest(const std::vector<mir::Operation::Output *
   scales[1] = getSingleArgument(op, "height_scale", 1.0f);
   scales[2] = getSingleArgument(op, "width_scale", 1.0f);
   scales[3] = 1;
-  auto resize = createOp<ops::ResizeOp>("ResizeNearest", convertCaffeToMIR(inputs[0]),
+  auto resize = createOp<ops::ResizeOp>(convertCaffeToMIR(inputs[0]),
                                         ops::ResizeOp::ResizeMethod::nearestNeighbor, scales);
   return {convertMIRToCaffe(resize->getOutput(0))};
 }
@@ -521,7 +518,7 @@ Caffe2OpCreator::convertResizeNearest(const std::vector<mir::Operation::Output *
 std::vector<mir::Operation::Output *>
 Caffe2OpCreator::convertSigmoid(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::SigmoidOp>("Sigmoid", inputs[0]);
+  auto result = createOp<ops::SigmoidOp>(inputs[0]);
   return {result->getOutput(0)};
 }
 
@@ -530,7 +527,7 @@ Caffe2OpCreator::convertSoftmax(const std::vector<mir::Operation::Output *> &inp
                                 const ::caffe2::OperatorDef &op)
 {
   int axis = getSingleArgument(op, "axis", 1);
-  auto softmax = createOp<ops::SoftmaxOp>("Softmax", inputs[0], axis);
+  auto softmax = createOp<ops::SoftmaxOp>(inputs[0], axis);
   return {softmax->getOutput(0)};
 }
 
@@ -568,19 +565,19 @@ Caffe2OpCreator::convertSpatialBN(const std::vector<mir::Operation::Output *> &i
   for (auto &idx : ShapeRange(bias_data.getShape()))
     bias_data.at(idx) *= -1;
 
-  auto mean = createOp<ops::ConstantOp>("Constant", mean_tensor)->getOutput(0);
-  auto result = createOp<ops::AddOp>("", convertCaffeToMIR(inputs[0]), mean)->getOutput(0);
+  auto mean = createOp<ops::ConstantOp>(mean_tensor)->getOutput(0);
+  auto result = createOp<ops::AddOp>(convertCaffeToMIR(inputs[0]), mean)->getOutput(0);
 
   // res2 = res1 * scale / (var + epsilon)
   Tensor<float> multiplier(scale_tensor);
   for (auto &idx : ShapeRange(scale_tensor.getShape()))
     multiplier.at(idx) /= std::sqrt(*reinterpret_cast<float *>(var_tensor.at(idx)) + eps);
-  auto scale = createOp<ops::ConstantOp>("Constant", scale_tensor)->getOutput(0);
-  result = createOp<ops::MulOp>("", result, scale)->getOutput(0);
+  auto scale = createOp<ops::ConstantOp>(scale_tensor)->getOutput(0);
+  result = createOp<ops::MulOp>(result, scale)->getOutput(0);
 
   // overall_res = res2 + bias
-  auto bias = createOp<ops::ConstantOp>("Constant", bias_tensor)->getOutput(0);
-  result = createOp<ops::AddOp>("", result, bias)->getOutput(0);
+  auto bias = createOp<ops::ConstantOp>(bias_tensor)->getOutput(0);
+  result = createOp<ops::AddOp>(result, bias)->getOutput(0);
 
   return {convertMIRToCaffe(result)};
 }
@@ -588,10 +585,10 @@ Caffe2OpCreator::convertSpatialBN(const std::vector<mir::Operation::Output *> &i
 std::vector<mir::Operation::Output *>
 Caffe2OpCreator::convertSum(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::AddOp>("", inputs[0], inputs[1])->getOutput(0);
+  auto result = createOp<ops::AddOp>(inputs[0], inputs[1])->getOutput(0);
   for (int i = 2; i < static_cast<int>(inputs.size()); ++i)
   {
-    result = createOp<ops::AddOp>("", result, inputs[i])->getOutput(0);
+    result = createOp<ops::AddOp>(result, inputs[i])->getOutput(0);
   }
   return {result};
 }
@@ -605,7 +602,7 @@ Caffe2OpCreator::convertClip(const std::vector<mir::Operation::Output *> &inputs
   float min = getSingleArgument(op, "min", float(0));
 
   assert(max > 0.0 && min == 0.0 && "Support only if clip is CappedRelu");
-  auto cap_relu = createOp<ops::CappedReluOp>("Capped_Relu", inputs[0], max);
+  auto cap_relu = createOp<ops::CappedReluOp>(inputs[0], max);
 
   return {cap_relu->getOutput(0)};
 }
@@ -630,14 +627,9 @@ Caffe2OpCreator::convertReshape(const std::vector<mir::Operation::Output *> &inp
   }
   Shape out_shape(shape_vec);
 
-  auto reshape = createOp<ops::ReshapeOp>("Reshape", inputs[0], out_shape);
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0], out_shape);
 
   return {reshape->getOutput(0)};
 }
 
-Operation::Output *Caffe2OpCreator::createInput(const std::string &name, const mir::Shape &shape)
-{
-  return _graph->create<ops::InputOp>(name, shape)->getOutput(0);
-}
-
 } // namespace mir_caffe2
index 41cda23..57424a2 100644 (file)
@@ -41,8 +41,6 @@ class Caffe2OpCreator
 public:
   explicit Caffe2OpCreator(mir::Graph *g) : _graph(g) {}
 
-  Operation::Output *createInput(const std::string &name, const mir::Shape &shape);
-
   std::vector<mir::Operation::Output *>
   convertConstant(const std::vector<mir::Operation::Output *> &inputs,
                   const ::caffe2::OperatorDef &op);
@@ -110,17 +108,13 @@ private:
 
   mir::Operation::Output *convertMIRToCaffe(mir::Operation::Output *arg);
 
-  template <typename OpType, typename... Types>
-  mir::Operation *createOp(const std::string &name, Types &&... args);
+  template <typename OpType, typename... Types> mir::Operation *createOp(Types &&... args);
 };
 
 template <typename OpType, typename... Types>
-mir::Operation *Caffe2OpCreator::createOp(const std::string &name, Types &&... args)
+mir::Operation *Caffe2OpCreator::createOp(Types &&... args)
 {
-  mir::Operation *new_op = _graph->create<OpType>("", std::forward<Types>(args)...);
-  std::string op_name = name + "_" + std::to_string(new_op->getId());
-  new_op->setName(op_name);
-  return new_op;
+  return _graph->create<OpType>(std::forward<Types>(args)...);
 }
 
 } // namespace mir_caffe2