[nnc] Fix switcher transformation for quantization (#8240)
author: Pavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics <p.iliutchenk@samsung.com>
Wed, 23 Oct 2019 14:16:15 +0000 (17:16 +0300)
committer: Alexander Efimov/./AI Tools Lab/Samsung Electronics <a.efimov@samsung.com>
Wed, 23 Oct 2019 14:16:15 +0000 (17:16 +0300)
* Fixed transpose inserters
* Fixed Conv2D and DepthwiseConv2D switchers

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/nnc/passes/transformations/DataFormatSwitcher.cpp

index 7b091e8..8ff8426 100644 (file)
@@ -86,22 +86,30 @@ void DataFormatSwitcher::cleanup() { _candidates_for_switch.clear(); }
 
 mir::Operation::Output *DataFormatSwitcher::insertTransposeBefore(mir::Operation::Output *out)
 {
+  mir::Operation::Output *new_out;
   if (_target_format == mir::DataFormat::NHWC)
-    return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
-        ->getOutput(0); // NCHW -> NHWC
+    new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
+                  ->getOutput(0); // NCHW -> NHWC
   else
-    return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
-        ->getOutput(0); // NHWC -> NCHW
+    new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
+                  ->getOutput(0); // NHWC -> NCHW
+  if (out->getType().isQuantized())
+    new_out->setQuantization(out->getType().getQuantization());
+  return new_out;
 }
 
 mir::Operation::Output *DataFormatSwitcher::insertTransposeAfter(mir::Operation::Output *out)
 {
+  mir::Operation::Output *new_out;
   if (_target_format == mir::DataFormat::NHWC)
-    return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
-        ->getOutput(0); // NHWC -> NCHW
+    new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 3, 1, 2})
+                  ->getOutput(0); // NHWC -> NCHW
   else
-    return _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
-        ->getOutput(0); // NCHW -> NHWC
+    new_out = _graph->create<mir::ops::TransposeOp>(out, std::vector<std::size_t>{0, 2, 3, 1})
+                  ->getOutput(0); // NCHW -> NHWC
+  if (out->getType().isQuantized())
+    new_out->setQuantization(out->getType().getQuantization());
+  return new_out;
 }
 
 void DataFormatSwitcher::switchAvgPool2D(mir::ops::AvgPool2DOp *op)
@@ -128,7 +136,7 @@ void DataFormatSwitcher::switchConv2D(mir::ops::Conv2DOp *op)
   if (op->getDataFormat() == _target_format)
     return;
 
-  assert(op->getNumInputs() == 2);
+  assert(op->getNumInputs() >= 2);
   auto *input = op->getInput(0);
   auto *kernel = op->getInput(1);
 
@@ -137,9 +145,19 @@ void DataFormatSwitcher::switchConv2D(mir::ops::Conv2DOp *op)
 
   auto *trans_in = insertTransposeBefore(input);
 
-  auto new_dw_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, attributes);
+  mir::Operation *new_conv;
+  if (op->getNumInputs() == 2)
+    new_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, attributes);
+  else
+  {
+    auto bias = op->getInput(2);
+    new_conv = _graph->create<mir::ops::Conv2DOp>(trans_in, kernel, bias, attributes);
+  }
+
+  if (op->getOutput(0)->getType().isQuantized())
+    new_conv->getOutput(0)->setQuantization(op->getOutput(0)->getType().getQuantization());
 
-  auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));
+  auto *trans_out = insertTransposeAfter(new_conv->getOutput(0));
 
   _graph->replaceNode(op, trans_out->getNode());
 }
@@ -182,7 +200,7 @@ void DataFormatSwitcher::switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op)
   if (op->getDataFormat() == _target_format)
     return;
 
-  assert(op->getNumInputs() == 2);
+  assert(op->getNumInputs() >= 2);
   auto *input = op->getInput(0);
   auto *kernel = op->getInput(1);
 
@@ -191,7 +209,17 @@ void DataFormatSwitcher::switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op)
 
   auto *trans_in = insertTransposeBefore(input);
 
-  auto new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, attributes);
+  mir::Operation *new_dw_conv;
+  if (op->getNumInputs() == 2)
+    new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, attributes);
+  else
+  {
+    auto bias = op->getInput(2);
+    new_dw_conv = _graph->create<mir::ops::DepthwiseConv2DOp>(trans_in, kernel, bias, attributes);
+  }
+
+  if (op->getOutput(0)->getType().isQuantized())
+    new_dw_conv->getOutput(0)->setQuantization(op->getOutput(0)->getType().getQuantization());
 
   auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));