abc63ae64d6dca0d07017e6f43bb60465cc6fac8
[platform/upstream/armnn.git] / src / armnnSerializer / test / ActivationSerializationTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "../Serializer.hpp"
7
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/INetwork.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnnDeserializer/IDeserializer.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13
14 #include <boost/test/unit_test.hpp>
15
16 #include <sstream>
17
18 BOOST_AUTO_TEST_SUITE(SerializerTests)
19
20 class VerifyActivationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
21 {
22 public:
23     void VisitActivationLayer(const armnn::IConnectableLayer* layer,
24                               const armnn::ActivationDescriptor& activationDescriptor,
25                               const char* name) override
26     {
27         IgnoreUnused(layer, activationDescriptor);
28         BOOST_TEST(name == "activation");
29     }
30 };
31
32 BOOST_AUTO_TEST_CASE(ActivationSerialization)
33 {
34     armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
35
36     armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
37     armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
38
39     // Construct network
40     armnn::INetworkPtr network = armnn::INetwork::Create();
41
42     armnn::ActivationDescriptor descriptor;
43     descriptor.m_Function = armnn::ActivationFunction::ReLu;
44     descriptor.m_A = 0;
45     descriptor.m_B = 0;
46
47     armnn::IConnectableLayer* const inputLayer      = network->AddInputLayer(0, "input");
48     armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
49     armnn::IConnectableLayer* const outputLayer     = network->AddOutputLayer(0, "output");
50
51     inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
52     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
53
54     activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
55     activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
56
57     armnnSerializer::Serializer serializer;
58     serializer.Serialize(*network);
59
60     std::stringstream stream;
61     serializer.SaveSerializedToStream(stream);
62
63     std::string const serializerString{stream.str()};
64     std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
65
66     armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
67
68     VerifyActivationName visitor;
69     deserializedNetwork->Accept(visitor);
70
71     armnn::IRuntime::CreationOptions options; // default options
72     armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
73     auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
74
75     armnn::NetworkId networkIdentifier;
76
77     // Load graph into runtime
78     run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
79
80     std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
81     armnn::InputTensors inputTensors
82     {
83         {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())}
84     };
85
86     std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
87
88     std::vector<float> outputData(4);
89     armnn::OutputTensors outputTensors
90     {
91         {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
92     };
93     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
94     BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(),
95     expectedOutputData.begin(), expectedOutputData.end());
96 }
97
98 BOOST_AUTO_TEST_SUITE_END()