IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / armnnUtils / TensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "TensorUtils.hpp"
7
8 namespace armnnUtils
9 {
10
11 armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
12                                   unsigned int numberOfChannels,
13                                   unsigned int height,
14                                   unsigned int width,
15                                   const armnn::DataLayout dataLayout)
16 {
17     switch (dataLayout)
18     {
19         case armnn::DataLayout::NCHW:
20             return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
21         case armnn::DataLayout::NHWC:
22             return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
23         default:
24             throw armnn::InvalidArgumentException("Unknown data layout ["
25                                                   + std::to_string(static_cast<int>(dataLayout)) +
26                                                   "]", CHECK_LOCATION());
27     }
28 }
29
30 armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
31                                 unsigned int numberOfChannels,
32                                 unsigned int height,
33                                 unsigned int width,
34                                 const armnn::DataLayout dataLayout,
35                                 const armnn::DataType dataType)
36 {
37     switch (dataLayout)
38     {
39         case armnn::DataLayout::NCHW:
40             return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
41         case armnn::DataLayout::NHWC:
42             return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
43         default:
44             throw armnn::InvalidArgumentException("Unknown data layout ["
45                                                   + std::to_string(static_cast<int>(dataLayout)) +
46                                                   "]", CHECK_LOCATION());
47     }
48 }
49
50 }