2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
5 #include <boost/test/unit_test.hpp>
6 #include <armnn/Tensor.hpp>
11 // Adds unit test framework for interpreting TensorInfo type.
12 std::ostream& boost_test_print_type(std::ostream& ostr, const TensorInfo& right)
14 ostr << "TensorInfo[ "
15 << right.GetNumDimensions() << ","
16 << right.GetShape()[0] << ","
17 << right.GetShape()[1] << ","
18 << right.GetShape()[2] << ","
19 << right.GetShape()[3]
24 std::ostream& boost_test_print_type(std::ostream& ostr, const TensorShape& shape)
26 ostr << "TensorShape[ "
27 << shape.GetNumDimensions() << ","
37 using namespace armnn;
39 BOOST_AUTO_TEST_SUITE(Tensor)
41 struct TensorInfoFixture
45 unsigned int sizes[] = {6,7,8,9};
46 m_TensorInfo = TensorInfo(4, sizes, DataType::Float32);
48 ~TensorInfoFixture() {};
50 TensorInfo m_TensorInfo;
53 BOOST_FIXTURE_TEST_CASE(ConstructShapeUsingListInitialization, TensorInfoFixture)
55 TensorShape listInitializedShape{ 6, 7, 8, 9 };
56 BOOST_TEST(listInitializedShape == m_TensorInfo.GetShape());
59 BOOST_FIXTURE_TEST_CASE(ConstructTensorInfo, TensorInfoFixture)
61 BOOST_TEST(m_TensorInfo.GetNumDimensions() == 4);
62 BOOST_TEST(m_TensorInfo.GetShape()[0] == 6); // <= Outer most
63 BOOST_TEST(m_TensorInfo.GetShape()[1] == 7);
64 BOOST_TEST(m_TensorInfo.GetShape()[2] == 8);
65 BOOST_TEST(m_TensorInfo.GetShape()[3] == 9); // <= Inner most
68 BOOST_FIXTURE_TEST_CASE(CopyConstructTensorInfo, TensorInfoFixture)
70 TensorInfo copyConstructed(m_TensorInfo);
71 BOOST_TEST(copyConstructed.GetNumDimensions() == 4);
72 BOOST_TEST(copyConstructed.GetShape()[0] == 6);
73 BOOST_TEST(copyConstructed.GetShape()[1] == 7);
74 BOOST_TEST(copyConstructed.GetShape()[2] == 8);
75 BOOST_TEST(copyConstructed.GetShape()[3] == 9);
78 BOOST_FIXTURE_TEST_CASE(TensorInfoEquality, TensorInfoFixture)
80 TensorInfo copyConstructed(m_TensorInfo);
81 BOOST_TEST(copyConstructed == m_TensorInfo);
84 BOOST_FIXTURE_TEST_CASE(TensorInfoInequality, TensorInfoFixture)
87 unsigned int sizes[] = {2,3,4,5};
88 other = TensorInfo(4, sizes, DataType::Float32);
90 BOOST_TEST(other != m_TensorInfo);
93 BOOST_FIXTURE_TEST_CASE(TensorInfoAssignmentOperator, TensorInfoFixture)
97 BOOST_TEST(copy == m_TensorInfo);
100 void CheckTensor(const ConstTensor& t)
105 BOOST_AUTO_TEST_CASE(TensorVsConstTensor)
107 int mutableDatum = 2;
108 const int immutableDatum = 3;
110 armnn::Tensor uninitializedTensor;
111 armnn::ConstTensor uninitializedTensor2;
113 uninitializedTensor2 = uninitializedTensor;
115 armnn::Tensor t(TensorInfo(), &mutableDatum);
116 armnn::ConstTensor ct(TensorInfo(), &immutableDatum);
118 // Checks that both Tensor and ConstTensor can be passed as a ConstTensor.
123 BOOST_AUTO_TEST_CASE(ModifyTensorInfo)
126 info.SetShape({ 5, 6, 7, 8 });
127 BOOST_TEST((info.GetShape() == TensorShape({ 5, 6, 7, 8 })));
128 info.SetDataType(DataType::QuantisedAsymm8);
129 BOOST_TEST((info.GetDataType() == DataType::QuantisedAsymm8));
130 info.SetQuantizationScale(10.0f);
131 BOOST_TEST(info.GetQuantizationScale() == 10.0f);
132 info.SetQuantizationOffset(5);
133 BOOST_TEST(info.GetQuantizationOffset() == 5);
136 BOOST_AUTO_TEST_CASE(TensorShapeOperatorBrackets)
138 TensorShape shape({0,1,2,3});
139 // Checks version of operator[] which returns an unsigned int.
140 BOOST_TEST(shape[2] == 2);
141 // Checks the version of operator[] which returns a reference.
143 BOOST_TEST(shape[2] == 20);
146 BOOST_AUTO_TEST_SUITE_END()