Release 18.08
src/armnn/layers/ResizeBilinearLayer.cpp (platform/upstream/armnn.git)
index 204d5af..6477fa3 100644
--- a/src/armnn/layers/ResizeBilinearLayer.cpp
+++ b/src/armnn/layers/ResizeBilinearLayer.cpp
@@ -30,23 +30,31 @@ ResizeBilinearLayer* ResizeBilinearLayer::Clone(Graph& graph) const
     return CloneBase<ResizeBilinearLayer>(graph, m_Param, GetName());
 }
 
-void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
 {
-    ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
-                     "MemCopyLayer: InputSlot must be connected to an OutputSlot");
-    ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(),
-                     "MemCopyLayer: TensorInfo must be set on connected OutputSlot.");
+    BOOST_ASSERT(inputShapes.size() == 1);
+    const TensorShape& inputShape = inputShapes[0];
 
-    const TensorShape& inputShape = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape();
     unsigned int outWidth = m_Param.m_TargetWidth;
     unsigned int outHeight = m_Param.m_TargetHeight;
     unsigned int outChannels = inputShape[1];
     unsigned int outBatch = inputShape[0];
-    TensorShape outShape({outBatch, outChannels, outHeight, outWidth});
+
+    return std::vector<TensorShape>({ TensorShape({outBatch, outChannels, outHeight, outWidth}) });
+}
+
+void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+{
+    VerifyLayerConnections(1, CHECK_LOCATION());
+
+    auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+    BOOST_ASSERT(inferredShapes.size() == 1);
+
     ConditionalThrowIfNotEqual<LayerValidationException>(
         "ResizeBilinearLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
         GetOutputSlot(0).GetTensorInfo().GetShape(),
-        outShape);
+        inferredShapes[0]);
 }
 
 } // namespace armnn
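
For reference, the shape rule introduced by the new InferOutputShapes above is: the batch and channel dimensions are copied from the (NCHW) input shape, while the output height and width come from the layer's m_TargetHeight and m_TargetWidth parameters. The standalone sketch below reproduces that rule outside ArmNN; TensorShape, ResizeBilinearDescriptor and InferResizeBilinearOutputShape here are simplified stand-ins for illustration, not the ArmNN types or API.

// Minimal sketch of the resize-bilinear shape inference rule shown in the diff.
// The types below are stand-ins, not ArmNN's own definitions.
#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

using TensorShape = std::array<unsigned int, 4>;   // NCHW: {batch, channels, height, width}

struct ResizeBilinearDescriptor                    // stand-in for the layer's m_Param
{
    unsigned int m_TargetWidth  = 0;
    unsigned int m_TargetHeight = 0;
};

// Batch and channels are taken from the input; height and width from the descriptor.
std::vector<TensorShape> InferResizeBilinearOutputShape(const std::vector<TensorShape>& inputShapes,
                                                        const ResizeBilinearDescriptor& param)
{
    assert(inputShapes.size() == 1);
    const TensorShape& in = inputShapes[0];
    return { TensorShape{ in[0], in[1], param.m_TargetHeight, param.m_TargetWidth } };
}

int main()
{
    ResizeBilinearDescriptor param;
    param.m_TargetWidth  = 8;
    param.m_TargetHeight = 6;

    // A 1x3x16x16 input resized to the 6x8 target yields a 1x3x6x8 output shape.
    const auto out = InferResizeBilinearOutputShape({ TensorShape{ 1, 3, 16, 16 } }, param);
    std::printf("%u %u %u %u\n", out[0][0], out[0][1], out[0][2], out[0][3]);   // prints: 1 3 6 8
    return 0;
}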