Release 18.08
[platform/upstream/armnn.git] / src/armnn/layers/FullyConnectedLayer.cpp
index 1597e8c..8b8f010 100644
@@ -22,11 +22,15 @@ FullyConnectedLayer::FullyConnectedLayer(const FullyConnectedDescriptor& param,
 std::unique_ptr<IWorkload> FullyConnectedLayer::CreateWorkload(const Graph& graph,
                                                                const IWorkloadFactory& factory) const
 {
+    // At this level the constant data should not have been released yet.
+    BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
+
     FullyConnectedQueueDescriptor descriptor;
 
     descriptor.m_Weight = m_Weight.get();
     if (m_Param.m_BiasEnabled)
     {
+        BOOST_ASSERT_MSG(m_Bias != nullptr, "FullyConnectedLayer: Bias data should not be null.");
         descriptor.m_Bias = m_Bias.get();
     }
     return factory.CreateFullyConnected(descriptor, PrepInfoAndDesc(descriptor, graph));
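The asserts added above protect workload creation from running after the layer's constant data has been released. A minimal standalone sketch (plain C++, not ArmNN code; the class and method names below are invented for illustration) of the ordering constraint they enforce:

    #include <cassert>
    #include <memory>
    #include <vector>

    // A layer-like class holding constant data in a unique_ptr, mirroring
    // m_Weight above. Workload creation must happen while the handle is still
    // populated; that is what the new BOOST_ASSERT_MSG checks enforce.
    struct SketchLayer
    {
        std::unique_ptr<std::vector<float>> m_Weight =
            std::make_unique<std::vector<float>>(512 * 128, 0.0f);

        const float* CreateWorkloadData() const
        {
            // At this level the constant data must not have been released yet.
            assert(m_Weight != nullptr && "Weights data should not be null.");
            return m_Weight->data();
        }

        void ReleaseConstantData() { m_Weight.reset(); }
    };

    // Correct ordering: create workloads first, release constant data afterwards.
    // Calling CreateWorkloadData() after ReleaseConstantData() would trip the assert.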
@@ -45,25 +49,41 @@ FullyConnectedLayer* FullyConnectedLayer::Clone(Graph& graph) const
     return std::move(layer);
 }
 
+std::vector<TensorShape> FullyConnectedLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+    BOOST_ASSERT(inputShapes.size() == 2);
+    const TensorShape& inputShape = inputShapes[0];
+    const TensorShape& weightShape = inputShapes[1];
+
+    // The output shape for FullyConnected is [batches, weightShape[dimIdx]].
+    unsigned int batches = inputShape[0];
+    unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1;
+
+    return std::vector<TensorShape>({ TensorShape({batches, weightShape[dimIdx]})});
+}
+
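For reference, here is a standalone sketch (plain C++, not part of the patch) of the same inference rule, together with a worked example: an input of shape [batches, inputSize] combined with a non-transposed weight matrix of shape [inputSize, outputSize] yields an output of shape [batches, outputSize].

    #include <cassert>
    #include <vector>

    // Mirrors FullyConnectedLayer::InferOutputShapes using plain vectors in
    // place of armnn::TensorShape.
    std::vector<unsigned int> InferFullyConnectedShape(const std::vector<unsigned int>& inputShape,
                                                       const std::vector<unsigned int>& weightShape,
                                                       bool transposeWeightMatrix)
    {
        assert(!inputShape.empty() && weightShape.size() == 2);
        const unsigned int batches = inputShape[0];
        // The output size sits in weight dimension 0 when the matrix is stored
        // transposed, otherwise in dimension 1.
        const unsigned int dimIdx = transposeWeightMatrix ? 0u : 1u;
        return { batches, weightShape[dimIdx] };
    }

    int main()
    {
        // Input [2, 512] with weights [512, 128], not transposed -> output [2, 128].
        const auto outShape = InferFullyConnectedShape({2, 512}, {512, 128}, false);
        assert(outShape.size() == 2 && outShape[0] == 2 && outShape[1] == 128);
        return 0;
    }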
 void FullyConnectedLayer::ValidateTensorShapesFromInputs()
 {
-    ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
-                     "FullyConnectedLayer: InputSlot must be connected to an OutputSlot");
-    ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(),
-                     "FullyConnectedLayer: TensorInfo must be set on connected OutputSlot.");
+    VerifyLayerConnections(1, CHECK_LOCATION());
 
+    // Check that the m_Weight data is not null.
+    BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
 
-    TensorShape const& weightShape = m_Weight->GetTensorInfo().GetShape();
+    auto inferredShapes = InferOutputShapes({
+        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+        m_Weight->GetTensorInfo().GetShape() });
 
-    // output for FC is [1, w[1]]
-    unsigned int batches = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()[0];
-    unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1;
-    TensorShape outShape({batches, weightShape[dimIdx]});
+    BOOST_ASSERT(inferredShapes.size() == 1);
 
     ConditionalThrowIfNotEqual<LayerValidationException>(
         "FullyConnectedLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
         GetOutputSlot(0).GetTensorInfo().GetShape(),
-        outShape);
+        inferredShapes[0]);
+}
+
+Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef()
+{
+    return {m_Weight, m_Bias};
 }
 
 } // namespace armnn
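GetConstantTensorsByRef() gives graph-level code access to the weight and bias handles by reference, without needing to know the concrete layer type. A hedged usage sketch follows; it assumes Layer::ConstantTensors is a collection of std::reference_wrapper around the layer's unique_ptr handle members, which is what returning {m_Weight, m_Bias} by brace-initialisation suggests.

    // Sketch only (assumed interface): release all constant data held by a layer.
    template <typename LayerType>
    void ReleaseConstants(LayerType& layer)
    {
        for (auto& handleRef : layer.GetConstantTensorsByRef())
        {
            auto& handle = handleRef.get(); // reference to m_Weight or m_Bias
            if (handle != nullptr)          // m_Bias may be null when bias is disabled
            {
                handle.reset();             // drop the constant tensor data
            }
        }
    }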