return CloneBase<ResizeBilinearLayer>(graph, m_Param, GetName());
}
-void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
- "MemCopyLayer: InputSlot must be connected to an OutputSlot");
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(),
- "MemCopyLayer: TensorInfo must be set on connected OutputSlot.");
+ BOOST_ASSERT(inputShapes.size() == 1);
+ const TensorShape& inputShape = inputShapes[0];
- const TensorShape& inputShape = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape();
unsigned int outWidth = m_Param.m_TargetWidth;
unsigned int outHeight = m_Param.m_TargetHeight;
unsigned int outChannels = inputShape[1];
unsigned int outBatch = inputShape[0];
- TensorShape outShape({outBatch, outChannels, outHeight, outWidth});
+
+ return std::vector<TensorShape>({ TensorShape({outBatch, outChannels, outHeight, outWidth}) });
+}
+
+void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(1, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
ConditionalThrowIfNotEqual<LayerValidationException>(
"ResizeBilinearLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
GetOutputSlot(0).GetTensorInfo().GetShape(),
- outShape);
+ inferredShapes[0]);
}
} // namespace armnn