Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / builders / batch_normalization_layer_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <ie_builders.hpp>
8 #include <builders/ie_batch_normalization_layer.hpp>
9
10 #include "builder_test.hpp"
11
12 using namespace testing;
13 using namespace InferenceEngine;
14
15 class BatchNormalizationLayerBuilderTest : public BuilderTestCommon {};
16
17 //TEST_F(BatchNormalizationLayerBuilderTest, cannotCreateBatchNormalizationWithoutWeightOrBiases) {
18 //    ASSERT_THROW(((Builder::Layer)Builder::BatchNormalizationLayer("in1")), InferenceEngine::details::InferenceEngineException);
19 //    ASSERT_THROW(((Builder::Layer)Builder::BatchNormalizationLayer("in1")
20 //            .setWeights(generateBlob(Precision::FP32, {3}, Layout::C))), InferenceEngine::details::InferenceEngineException);
21 //    ASSERT_THROW(((Builder::Layer)Builder::BatchNormalizationLayer("in1")
22 //            .setBiases(generateBlob(Precision::FP32, {3}, Layout::C))), InferenceEngine::details::InferenceEngineException);
23 //}
24
25 TEST_F(BatchNormalizationLayerBuilderTest, getExistsLayerFromNetworkBuilder) {
26     Builder::Network network("Test");
27     idx_t weightsId = network.addLayer(Builder::ConstLayer("weights").setData(generateBlob(Precision::FP32, {3}, Layout::C)));
28     idx_t biasesId = network.addLayer(Builder::ConstLayer("biases").setData(generateBlob(Precision::FP32, {3}, Layout::C)));
29     Builder::BatchNormalizationLayer bnBuilder("bn");
30     idx_t bnId = network.addLayer({{0}, {weightsId}, {biasesId}}, bnBuilder);
31     Builder::BatchNormalizationLayer bnBuilderFromNetwork(network.getLayer(bnId));
32     ASSERT_EQ(bnBuilderFromNetwork.getEpsilon(), bnBuilder.getEpsilon());
33     bnBuilderFromNetwork.setEpsilon(2);
34     ASSERT_NE(bnBuilderFromNetwork.getEpsilon(), bnBuilder.getEpsilon());
35     ASSERT_EQ(bnBuilderFromNetwork.getEpsilon(), network.getLayer(bnId)->getParameters()["epsilon"].as<float>());
36 }