2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include "ResizeBilinearLayer.hpp"
7 #include "LayerCloneBase.hpp"
9 #include <armnn/TypesUtils.hpp>
10 #include <backendsCommon/WorkloadData.hpp>
11 #include <backendsCommon/WorkloadFactory.hpp>
16 ResizeBilinearLayer::ResizeBilinearLayer(const ResizeBilinearDescriptor& param, const char* name)
17 : LayerWithParameters(1, 1, LayerType::ResizeBilinear, param, name)
21 std::unique_ptr<IWorkload> ResizeBilinearLayer::CreateWorkload(const Graph& graph,
22 const IWorkloadFactory& factory) const
24 ResizeBilinearQueueDescriptor descriptor;
25 return factory.CreateResizeBilinear(descriptor, PrepInfoAndDesc(descriptor, graph));
28 ResizeBilinearLayer* ResizeBilinearLayer::Clone(Graph& graph) const
30 return CloneBase<ResizeBilinearLayer>(graph, m_Param, GetName());
33 std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
35 BOOST_ASSERT(inputShapes.size() == 1);
36 const TensorShape& inputShape = inputShapes[0];
38 unsigned int outWidth = m_Param.m_TargetWidth;
39 unsigned int outHeight = m_Param.m_TargetHeight;
40 unsigned int outChannels = inputShape[m_Param.m_DataLayout.GetChannelsIndex()];
41 unsigned int outBatch = inputShape[0];
43 TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
44 TensorShape( { outBatch, outHeight, outWidth, outChannels } ) :
45 TensorShape( { outBatch, outChannels, outHeight, outWidth });
47 return std::vector<TensorShape>({ tensorShape });
50 void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
52 VerifyLayerConnections(1, CHECK_LOCATION());
54 auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
56 BOOST_ASSERT(inferredShapes.size() == 1);
58 ConditionalThrowIfNotEqual<LayerValidationException>(
59 "ResizeBilinearLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
60 GetOutputSlot(0).GetTensorInfo().GetShape(),