Release 18.08
[platform/upstream/armnn.git] / src / armnn / test / TensorTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include <boost/test/unit_test.hpp>
6 #include <armnn/Tensor.hpp>
7
8 namespace armnn
9 {
10
11 // Adds unit test framework for interpreting TensorInfo type.
12 std::ostream& boost_test_print_type(std::ostream& ostr, const TensorInfo& right)
13 {
14     ostr << "TensorInfo[ "
15     << right.GetNumDimensions() << ","
16     << right.GetShape()[0] << ","
17     << right.GetShape()[1] << ","
18     << right.GetShape()[2] << ","
19     << right.GetShape()[3]
20     << " ]" << std::endl;
21     return ostr;
22 }
23
24 std::ostream& boost_test_print_type(std::ostream& ostr, const TensorShape& shape)
25 {
26     ostr << "TensorShape[ "
27         << shape.GetNumDimensions() << ","
28         << shape[0] << ","
29         << shape[1] << ","
30         << shape[2] << ","
31         << shape[3]
32         << " ]" << std::endl;
33     return ostr;
34 }
35
36 } //namespace armnn
37 using namespace armnn;
38
39 BOOST_AUTO_TEST_SUITE(Tensor)
40
41 struct TensorInfoFixture
42 {
43     TensorInfoFixture()
44     {
45         unsigned int sizes[] = {6,7,8,9};
46         m_TensorInfo = TensorInfo(4, sizes, DataType::Float32);
47     }
48     ~TensorInfoFixture() {};
49
50     TensorInfo m_TensorInfo;
51 };
52
53 BOOST_FIXTURE_TEST_CASE(ConstructShapeUsingListInitialization, TensorInfoFixture)
54 {
55     TensorShape listInitializedShape{ 6, 7, 8, 9 };
56     BOOST_TEST(listInitializedShape == m_TensorInfo.GetShape());
57 }
58
59 BOOST_FIXTURE_TEST_CASE(ConstructTensorInfo, TensorInfoFixture)
60 {
61     BOOST_TEST(m_TensorInfo.GetNumDimensions() == 4);
62     BOOST_TEST(m_TensorInfo.GetShape()[0] == 6); // <= Outer most
63     BOOST_TEST(m_TensorInfo.GetShape()[1] == 7);
64     BOOST_TEST(m_TensorInfo.GetShape()[2] == 8);
65     BOOST_TEST(m_TensorInfo.GetShape()[3] == 9);     // <= Inner most
66 }
67
68 BOOST_FIXTURE_TEST_CASE(CopyConstructTensorInfo, TensorInfoFixture)
69 {
70     TensorInfo copyConstructed(m_TensorInfo);
71     BOOST_TEST(copyConstructed.GetNumDimensions() == 4);
72     BOOST_TEST(copyConstructed.GetShape()[0] == 6);
73     BOOST_TEST(copyConstructed.GetShape()[1] == 7);
74     BOOST_TEST(copyConstructed.GetShape()[2] == 8);
75     BOOST_TEST(copyConstructed.GetShape()[3] == 9);
76 }
77
78 BOOST_FIXTURE_TEST_CASE(TensorInfoEquality, TensorInfoFixture)
79 {
80     TensorInfo copyConstructed(m_TensorInfo);
81     BOOST_TEST(copyConstructed == m_TensorInfo);
82 }
83
84 BOOST_FIXTURE_TEST_CASE(TensorInfoInequality, TensorInfoFixture)
85 {
86     TensorInfo other;
87     unsigned int sizes[] = {2,3,4,5};
88     other = TensorInfo(4, sizes, DataType::Float32);
89
90     BOOST_TEST(other != m_TensorInfo);
91 }
92
93 BOOST_FIXTURE_TEST_CASE(TensorInfoAssignmentOperator, TensorInfoFixture)
94 {
95     TensorInfo copy;
96     copy = m_TensorInfo;
97     BOOST_TEST(copy == m_TensorInfo);
98 }
99
100 void CheckTensor(const ConstTensor& t)
101 {
102     t.GetInfo();
103 }
104
105 BOOST_AUTO_TEST_CASE(TensorVsConstTensor)
106 {
107     int mutableDatum = 2;
108     const int immutableDatum = 3;
109
110     armnn::Tensor uninitializedTensor;
111     armnn::ConstTensor uninitializedTensor2;
112
113     uninitializedTensor2 = uninitializedTensor;
114
115     armnn::Tensor t(TensorInfo(), &mutableDatum);
116     armnn::ConstTensor ct(TensorInfo(), &immutableDatum);
117
118     // Checks that both Tensor and ConstTensor can be passed as a ConstTensor.
119     CheckTensor(t);
120     CheckTensor(ct);
121 }
122
123 BOOST_AUTO_TEST_CASE(ModifyTensorInfo)
124 {
125     TensorInfo info;
126     info.SetShape({ 5, 6, 7, 8 });
127     BOOST_TEST((info.GetShape() == TensorShape({ 5, 6, 7, 8 })));
128     info.SetDataType(DataType::QuantisedAsymm8);
129     BOOST_TEST((info.GetDataType() == DataType::QuantisedAsymm8));
130     info.SetQuantizationScale(10.0f);
131     BOOST_TEST(info.GetQuantizationScale() == 10.0f);
132     info.SetQuantizationOffset(5);
133     BOOST_TEST(info.GetQuantizationOffset() == 5);
134 }
135
136 BOOST_AUTO_TEST_CASE(TensorShapeOperatorBrackets)
137 {
138     TensorShape shape({0,1,2,3});
139     // Checks version of operator[] which returns an unsigned int.
140     BOOST_TEST(shape[2] == 2);
141     // Checks the version of operator[] which returns a reference.
142     shape[2] = 20;
143     BOOST_TEST(shape[2] == 20);
144 }
145
146 BOOST_AUTO_TEST_SUITE_END()