2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
#include "../Serializer.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <boost/test/unit_test.hpp>

#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
18 BOOST_AUTO_TEST_SUITE(SerializerTests)
20 class VerifyActivationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
23 void VisitActivationLayer(const armnn::IConnectableLayer* layer,
24 const armnn::ActivationDescriptor& activationDescriptor,
25 const char* name) override
27 IgnoreUnused(layer, activationDescriptor);
28 BOOST_TEST(name == "activation");
32 BOOST_AUTO_TEST_CASE(ActivationSerialization)
34 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
36 armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
37 armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
40 armnn::INetworkPtr network = armnn::INetwork::Create();
42 armnn::ActivationDescriptor descriptor;
43 descriptor.m_Function = armnn::ActivationFunction::ReLu;
47 armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input");
48 armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
49 armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
51 inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
52 inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
54 activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
55 activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
57 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
59 serializer->Serialize(*network);
61 std::stringstream stream;
62 serializer->SaveSerializedToStream(stream);
64 std::string const serializerString{stream.str()};
65 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
67 armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
69 VerifyActivationName visitor;
70 deserializedNetwork->Accept(visitor);
72 armnn::IRuntime::CreationOptions options; // default options
73 armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
74 auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
76 armnn::NetworkId networkIdentifier;
78 // Load graph into runtime
79 run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
81 std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
82 armnn::InputTensors inputTensors
84 {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())}
87 std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
89 std::vector<float> outputData(4);
90 armnn::OutputTensors outputTensors
92 {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
94 run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
95 BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(),
96 expectedOutputData.begin(), expectedOutputData.end());
99 BOOST_AUTO_TEST_SUITE_END()