From 40e1865f110e67e9621c98412cc9f1438ada7068 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Denis=20Maksimenko/AI=20Tools=20Lab=20/SRR/Assistant=20Engi?= =?utf8?q?neer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 14 Aug 2018 16:50:23 +0300 Subject: [PATCH] Fixed FullyConnected. (#1008) Fixed ShapeInference for FullyConnected layer, fixed interpreter implementation of FUllyConnected op. Signed-off-by: Denis Maksimenko --- .../interpreter/core/include/interpreter/ops/FullyConnected.h | 6 +++--- contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/FullyConnected.h b/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/FullyConnected.h index dda586c..12ce416 100644 --- a/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/FullyConnected.h +++ b/contrib/nnc/libs/backend/interpreter/core/include/interpreter/ops/FullyConnected.h @@ -38,7 +38,7 @@ public: const Shape &inShape = _input.getShape(); uint32_t inRank = inShape.rank(); - assert(wShape.dim(wRank - 2) == inShape.dim(inRank - 1)); + assert(inShape.dim(inRank - 1) == wShape.dim(wRank - 2)); const uint32_t len = wShape.dim(wRank - 2); @@ -53,10 +53,10 @@ public: for (uint32_t i = 0u; i < len; ++i) { tIdx.at(wRank - 1) = i; - const T& w = weights.at(tIdx); + const T& in = _input.at(tIdx); tIdx.at(wRank - 1) = col; tIdx.at(wRank - 2) = i; - const T& in = _input.at(tIdx); + const T& w = weights.at(tIdx); tIdx.at(wRank - 2) = row; outputElement += w * in; } diff --git a/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp index 7f9f852..2f382f4 100644 --- a/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp +++ b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp @@ -179,8 +179,8 @@ void ShapeInference::visit(ADT::INode::Ref node, ops::FullyConnectedOp &op) } Shape outShape = wShape; - outShape.dim(weightsRank - 1) = inShape.dim(weightsRank - 1); - outShape.dim(weightsRank - 2) = wShape.dim(weightsRank - 2); + outShape.dim(weightsRank - 1) = wShape.dim(weightsRank - 1); + outShape.dim(weightsRank - 2) = inShape.dim(weightsRank - 2); op.setOutputShape(0, outShape); } -- 2.7.4