IVGCVSW-2144: Adding TensorUtils class
[platform/upstream/armnn.git] / src / armnnUtils / TensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "TensorUtils.hpp"
7
8 namespace armnnUtils
9 {
10
11 armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
12                                   unsigned int numberOfChannels,
13                                   unsigned int height,
14                                   unsigned int width,
15                                   const armnn::DataLayout dataLayout)
16 {
17     switch (dataLayout)
18     {
19         case armnn::DataLayout::NCHW:
20             return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
21         case armnn::DataLayout::NHWC:
22             return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
23         default:
24             throw armnn::InvalidArgumentException("Unknown data layout ["
25                                                   + std::to_string(static_cast<int>(dataLayout)) +
26                                                   "]", CHECK_LOCATION());
27     }
28 }
29
30 }
31