[mir_onnx] Refactor importing of BatchNormalization (#6466)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 9 Aug 2019 20:27:30 +0000 (23:27 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 9 Aug 2019 20:27:30 +0000 (23:27 +0300)
Structurize the code, add comments.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-onnx-importer/Op/BatchNormalization.cpp

index e22c89d..97aa420 100644 (file)
@@ -35,36 +35,45 @@ BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
                                          const std::vector<mir::Operation::Output *> &inputs,
                                          mir::Graph *graph) const
 {
-  // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
-  // 1e-05f is the default epsilon
-  float epsilon = getFloatAttribute(onnx_node, "epsilon", 1e-05f);
+  assert(inputs.size() == 5);
+  auto input = inputs[0];
+  auto scale = inputs[1];
+  auto bias = inputs[2];
+  auto mean = inputs[3];
+  auto var = inputs[4];
 
-  const auto &scale_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode())->getValue();
-  const auto &bias_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[2]->getNode())->getValue();
-  const auto &mean_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[3]->getNode())->getValue();
-  const auto &var_tensor = dynamic_cast<mir::ops::ConstantOp *>(inputs[4]->getNode())->getValue();
+  // 1e-05f is the default epsilon.
+  const float epsilon = getFloatAttribute(onnx_node, "epsilon", 1e-05f);
 
-  // res1 = X - mean
-  mir::Tensor<float> bias_data(mean_tensor);
-  for (auto &idx : mir::ShapeRange(bias_data.getShape()))
-    bias_data.at(idx) *= -1;
+  // Y = (X - mean) * scale / sqrt(var + epsilon) + bias =
+  //   = (X + C1) * C2 + bias
+  // We need these to be constants since we are going to change them.
+  // TODO Implement the formula using ops and let the optimizer constant-fold them.
+  auto scale_op = dynamic_cast<mir::ops::ConstantOp *>(scale->getNode());
+  auto mean_op = dynamic_cast<mir::ops::ConstantOp *>(mean->getNode());
+  auto var_op = dynamic_cast<mir::ops::ConstantOp *>(var->getNode());
 
-  auto data = convertONNXToMIR(graph, inputs[0]);
-  auto mean = createOp<mir::ops::ConstantOp>(graph, mean_tensor)->getOutput(0);
-  auto result = createOp<mir::ops::AddOp>(graph, data, mean)->getOutput(0);
+  if (scale_op == nullptr || mean_op == nullptr || var_op == nullptr)
+    throw std::runtime_error(
+        "BatchNormalization: only constant 'scale', 'mean' and 'variance' inputs are supported.");
 
-  // res2 = res1 * scale / (var + epsilon)
-  mir::Tensor<float> multiplier(scale_tensor);
-  mir::Tensor<float> var_accessor(var_tensor);
-  for (auto &idx : mir::ShapeRange(scale_tensor.getShape()))
-    multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
-  auto scale = createOp<mir::ops::ConstantOp>(graph, scale_tensor)->getOutput(0);
-  result = createOp<mir::ops::MulOp>(graph, result, scale)->getOutput(0);
+  mir::Tensor<float> scale_accessor(scale_op->getValue());
+  mir::Tensor<float> mean_accessor(mean_op->getValue());
+  mir::Tensor<float> var_accessor(var_op->getValue());
 
-  // overall_res = res2 + bias
-  auto bias = createOp<mir::ops::ConstantOp>(graph, bias_tensor)->getOutput(0);
-  result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
+  // C1 = -mean
+  for (const auto &idx : mir::ShapeRange(mean_accessor.getShape()))
+    mean_accessor.at(idx) *= -1;
+
+  // C2 = scale / sqrt(var + epsilon)
+  for (const auto &idx : mir::ShapeRange(scale_accessor.getShape()))
+    scale_accessor.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
 
+  // Y = (X + C1) * C2 + bias
+  input = convertONNXToMIR(graph, input);
+  auto result = createOp<mir::ops::AddOp>(graph, input, mean)->getOutput(0);
+  result = createOp<mir::ops::MulOp>(graph, result, scale)->getOutput(0);
+  result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
   return {convertMIRToONNX(graph, result)};
 }