[nnc] Fix special case of bias layer after fc layer in ACL backend (#2448)
authorEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 30 Nov 2018 12:24:54 +0000 (15:24 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Fri, 30 Nov 2018 12:24:54 +0000 (15:24 +0300)
Handle shape of bias weights with respect to previous operation.
Need to use different shape for operations after fully connected layer,
because it restores batch dimension in output tensor

Signed-off-by: Efimov Alexander <a.efimov@samsung.com>
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp

index 47c7cc9..e79db22 100644 (file)
@@ -5,6 +5,7 @@
 #include "option/Options.h"
 #include "core/modelIR/Tensor.h"
 
+#include "core/modelIR/Operation.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
@@ -278,11 +279,18 @@ void AclCppOpGenerator::visit(ops::BiasAddOp& op) {
   const auto ir_input_shape = op.getInputShape(0);
   ir_biases_shape.resize(ir_input_shape.rank());
 
-  // ACL CLArithmeticAddition supports input tensors broadcasting.
-  for (int i = 0; i < ir_input_shape.rank() - 1; ++i)
-    ir_biases_shape.dim(i) = 1;
+  // TODO remove this if after batch axis is restored in all operations in Model IR
+  if (op.getPrevNodes()[0].op->getType() == Operation::Type::fullyConnected) {
+    // Fully connected layer restores batch axis in result, so need to copy shape with redundant 1
+    // Shape transpose is needed to generate axises in reverse order
+    ir_biases_shape = transposeShape<1, 0>(op.getInputShape(0));
+  } else {
+    // ACL CLArithmeticAddition supports input tensors broadcasting.
+    for (int i = 0; i < ir_input_shape.rank() - 1; ++i)
+      ir_biases_shape.dim(i) = 1;
 
-  ir_biases_shape.dim(-1) = ir_biases.getShape().dim(0);
+    ir_biases_shape.dim(-1) = ir_biases.getShape().dim(0);
+  }
   auto biases = genTensor(operation_name + "_biases", ir_biases_shape);
 
   // Instantiate the CLArithmeticAddition object.