Fixed FullyConnected. (#1008)
authorDenis Maksimenko/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <d.maksimenko@partner.samsung.com>
Tue, 14 Aug 2018 13:50:23 +0000 (16:50 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 14 Aug 2018 13:50:23 +0000 (16:50 +0300)
Fixed ShapeInference for FullyConnected layer, fixed interpreter implementation of FUllyConnected op.

Signed-off-by: Denis Maksimenko <d.maksimenko@partner.samsung.com>
contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/FullyConnected.h
contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp

index dda586c..12ce416 100644 (file)
@@ -38,7 +38,7 @@ public:
     const Shape &inShape = _input.getShape();
     uint32_t inRank = inShape.rank();
 
-    assert(wShape.dim(wRank - 2) == inShape.dim(inRank - 1));
+    assert(inShape.dim(inRank - 1) == wShape.dim(wRank - 2));
 
     const uint32_t len = wShape.dim(wRank - 2);
 
@@ -53,10 +53,10 @@ public:
       for (uint32_t i = 0u; i < len; ++i)
       {
         tIdx.at(wRank - 1) = i;
-        const T& w = weights.at(tIdx);
+        const T& in = _input.at(tIdx);
         tIdx.at(wRank - 1) = col;
         tIdx.at(wRank - 2) = i;
-        const T& in = _input.at(tIdx);
+        const T& w = weights.at(tIdx);
         tIdx.at(wRank - 2) = row;
         outputElement += w * in;
       }
index 7f9f852..2f382f4 100644 (file)
@@ -179,8 +179,8 @@ void ShapeInference::visit(ADT::INode::Ref node, ops::FullyConnectedOp &op)
   }
 
   Shape outShape = wShape;
-  outShape.dim(weightsRank - 1) = inShape.dim(weightsRank - 1);
-  outShape.dim(weightsRank - 2) = wShape.dim(weightsRank - 2);
+  outShape.dim(weightsRank - 1) = wShape.dim(weightsRank - 1);
+  outShape.dim(weightsRank - 2) = inShape.dim(weightsRank - 2);
   op.setOutputShape(0, outShape);
 }