// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
6 #include "TensorUtils.hpp"
11 armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
12 unsigned int numberOfChannels,
15 const armnn::DataLayout dataLayout)
19 case armnn::DataLayout::NCHW:
20 return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
21 case armnn::DataLayout::NHWC:
22 return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
24 throw armnn::InvalidArgumentException("Unknown data layout ["
25 + std::to_string(static_cast<int>(dataLayout)) +
26 "]", CHECK_LOCATION());
30 armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
31 unsigned int numberOfChannels,
34 const armnn::DataLayout dataLayout,
35 const armnn::DataType dataType)
39 case armnn::DataLayout::NCHW:
40 return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
41 case armnn::DataLayout::NHWC:
42 return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
44 throw armnn::InvalidArgumentException("Unknown data layout ["
45 + std::to_string(static_cast<int>(dataLayout)) +
46 "]", CHECK_LOCATION());