[nnc] Support batch in InnerProduct layer in Caffe importer (#2599)
author: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Tue, 11 Dec 2018 08:22:34 +0000 (11:22 +0300)
committer: Роман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Tue, 11 Dec 2018 08:22:34 +0000 (11:22 +0300)
Add support for batch size to InnerProduct layer.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp

index 283134b..0360361 100644 (file)
@@ -255,8 +255,8 @@ CaffeOpCreator::convertDeconvolution(const caffe::LayerParameter& layer,
 
 void CaffeOpCreator::checkInnerProduct(const InnerProductParameter& opts,
                                        std::set<std::string>& problemsOpSet) {
-  if (opts.has_axis() && opts.axis() != 1)
-    problemsOpSet.insert("Fully Connected: layer axis param is not supported yet");
+  if (opts.axis() != 1)
+    problemsOpSet.insert("InnerProduct: unsupported axis");
 }
 
 /**
@@ -276,13 +276,11 @@ CaffeOpCreator::convertInnerProduct(const LayerParameter& layer,
   if (!opts.transpose())
     weights = transposeTensor<1, 0>(weights);
 
-  // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize]
-  // It is needed because Caffe InnerProduct layer takes NCHW input and flattens the CHW part.
-  int32_t fc_input_size = static_cast<int32_t>(
-                                  weights->getShape().numElements()) / opts.num_output();
-  auto reshape = createOp<ops::ReshapeOp>(layer.name() + ".reshape", inputs[0],
-                                          Shape{1, fc_input_size});
-
+  auto& input_shape = inputs[0].op->getOutputShape(inputs[0].index);
+  // Transform input into 2-D tensor by flattening axes before/after opts.axis().
+  assert(opts.axis() == 1);
+  Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
+  auto reshape = createOp<ops::ReshapeOp>(layer.name() + ".reshape", inputs[0], shape);
   auto fully_connected = createOp<ops::FullyConnectedOp>(layer.name() + ".fc",
                                                          reshape->getOutput(0), *weights);