From 1594489a2f82cdf737fd72d26ff4c6944c37449e Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Thu, 22 Oct 2020 13:21:23 +0300 Subject: [PATCH] Added new version of BatchNormInference (#2728) * Added new version of BatchNormInference * Fixed code style * Fixed batch norm inference v5 * Added opset4 and opset5 to IE backend * Fixed functional test * Fixed cpuFunc tests * Fixed transformation order * Try to fix validation * Revert some changes * Updated python API and added tests * Fixed code style * Fixed python code style * Disabled test --- docs/doxygen/ie_docs.xml | 2 +- docs/ops/normalization/BatchNormInference_5.md | 98 ++++++ docs/ops/opset5.md | 2 +- .../src/convert_function_to_cnn_network.cpp | 1 - .../legacy_api/src/ie_cnn_layer_builder_ngraph.cpp | 6 - .../src/readers/ir_reader/ie_ir_parser.cpp | 15 - .../op_conversions/batch_norm_decomposition.hpp | 8 + .../common_optimizations/common_optimizations.cpp | 1 + .../op_conversions/batch_norm_decomposition.cpp | 93 +++++- .../ngraph_reader/batch_norm_inference_tests.cpp | 4 +- .../tests/ngraph_functions/src/batch_norm.cpp | 2 +- ngraph/core/include/ngraph/op/batch_norm.hpp | 43 ++- ngraph/core/include/ngraph/opsets/opset5_tbl.hpp | 2 +- ngraph/core/src/op/batch_norm.cpp | 73 ++++- ngraph/python/src/ngraph/opset5/__init__.py | 2 +- ngraph/python/src/ngraph/opset5/ops.py | 26 ++ ngraph/test/backend/batch_norm.in.cpp | 101 +++++- ngraph/test/op_is.cpp | 2 +- ngraph/test/runtime/ie/unit_test.manifest | 1 + ngraph/test/runtime/interpreter/int_executable.hpp | 18 +- ngraph/test/runtime/interpreter/opset_int_tbl.hpp | 1 + ngraph/test/runtime/opset0_tbl.hpp | 2 +- ngraph/test/type_prop/batch_norm.cpp | 339 ++++++++++++++++++++- 23 files changed, 767 insertions(+), 75 deletions(-) create mode 100644 docs/ops/normalization/BatchNormInference_5.md diff --git a/docs/doxygen/ie_docs.xml b/docs/doxygen/ie_docs.xml index 5ea61f8..0406f63 100644 --- a/docs/doxygen/ie_docs.xml +++ b/docs/doxygen/ie_docs.xml @@ -113,7 +113,7 @@ - + diff --git a/docs/ops/normalization/BatchNormInference_5.md b/docs/ops/normalization/BatchNormInference_5.md new file mode 100644 index 0000000..aab4dae --- /dev/null +++ b/docs/ops/normalization/BatchNormInference_5.md @@ -0,0 +1,98 @@ +## BatchNormInference {#openvino_docs_ops_normalization_BatchNormInference_5} + +**Versioned name**: *BatchNormInference-5 + +**Category**: *Normalization* + +**Short description**: *BatchNormInference* layer normalizes a `input` tensor by `mean` and `variance`, and applies a scale (`gamma`) to it, as well as an offset (`beta`). + +**Attributes**: + +* *epsilon* + * **Description**: *epsilon* is the number to be added to the variance to avoid division by zero when normalizing a value. For example, *epsilon* equal to 0.001 means that 0.001 is added to the variance. + * **Range of values**: a positive floating-point number + * **Type**: `float` + * **Default value**: None + * **Required**: *yes* + +**Inputs** + +* **1**: `input` - input tensor with data for normalization. At least a 2D tensor of type T, the second dimension represents the channel axis and must have a span of at least 1. **Required.** +* **2**: `gamma` - gamma scaling for normalized value. A 1D tensor of type T with the same span as input's channel axis. **Required.** +* **3**: `beta` - bias added to the scaled normalized value. A 1D tensor of type T with the same span as input's channel axis.. **Required.** +* **4**: `mean` - value for mean normalization. A 1D tensor of type T with the same span as input's channel axis.. **Required.** +* **5**: `variance` - value for variance normalization. A 1D tensor of type T with the same span as input's channel axis.. **Required.** + +**Outputs** + +* **1**: The result of normalization. A tensor of the same type and shape with 1st input tensor. + +**Types** + +* *T*: any numeric type. + +**Mathematical Formulation** + +*BatchNormInference* normalizes the output in each hidden layer. +* **Input**: Values of \f$x\f$ over a mini-batch: + \f[ + \beta = \{ x_{1...m} \} + \f] +* **Parameters to learn**: \f$ \gamma, \beta\f$ +* **Output**: + \f[ + \{ o_{i} = BN_{\gamma, \beta} ( b_{i} ) \} + \f] +* **Mini-batch mean**: + \f[ + \mu_{\beta} \leftarrow \frac{1}{m}\sum_{i=1}^{m}b_{i} + \f] +* **Mini-batch variance**: + \f[ + \sigma_{\beta }^{2}\leftarrow \frac{1}{m}\sum_{i=1}^{m} ( b_{i} - \mu_{\beta} )^{2} + \f] +* **Normalize**: + \f[ + \hat{b_{i}} \leftarrow \frac{b_{i} - \mu_{\beta}}{\sqrt{\sigma_{\beta }^{2} + \epsilon }} + \f] +* **Scale and shift**: + \f[ + o_{i} \leftarrow \gamma\hat{b_{i}} + \beta = BN_{\gamma ,\beta } ( b_{i} ) + \f] + +**Example** + +```xml + + + + + 1 + 3 + 224 + 224 + + + 3 + + + 3 + + + 3 + + + 3 + + + + + 1 + 3 + 224 + 224 + + + +``` + diff --git a/docs/ops/opset5.md b/docs/ops/opset5.md index 75db702..7db25f8 100644 --- a/docs/ops/opset5.md +++ b/docs/ops/opset5.md @@ -19,7 +19,7 @@ declared in `namespace opset5`. * [Atan](arithmetic/Atan_1.md) * [Atanh](arithmetic/Atanh_3.md) * [AvgPool](pooling/AvgPool_1.md) -* [BatchNormInference](normalization/BatchNormInference_1.md) +* [BatchNormInference](normalization/BatchNormInference_5.md) * [BatchToSpace](movement/BatchToSpace_2.md) * [BinaryConvolution](convolution/BinaryConvolution_1.md) * [Broadcast](movement/Broadcast_3.md) diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp index 8a992a4..0cf25b9 100644 --- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp +++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp @@ -732,7 +732,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr>(), std::make_shared>(), std::make_shared>(), - std::make_shared>(), std::make_shared>(), std::make_shared>(), std::make_shared>(), diff --git a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp index c9adf53..b1274b9 100644 --- a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp +++ b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp @@ -668,12 +668,6 @@ CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ } template <> -CNNLayer::Ptr NodeConverter::createLayer( - const std::shared_ptr& layer) const { - THROW_IE_EXCEPTION << "BatchNormInference operation should be fused or decomposed"; -} - -template <> CNNLayer::Ptr NodeConverter::createLayer(const std::shared_ptr& layer) const { LayerParams params = {layer->get_friendly_name(), "Squeeze", details::convertPrecision(layer->get_output_element_type(0))}; diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp index d07a3e9..a820839 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp @@ -394,7 +394,6 @@ std::shared_ptr V10Parser::createNode(const std::vector>("Asin"), std::make_shared>("Atan"), std::make_shared>("AvgPool"), - std::make_shared>("BatchNormInference"), std::make_shared>("Ceiling"), std::make_shared>("Clamp"), std::make_shared>("Concat"), @@ -951,20 +950,6 @@ std::shared_ptr V10Parser::LayerCreator: activations, activations_alpha, activations_beta, clip); } -// BatchNormInference layer -template <> -std::shared_ptr V10Parser::LayerCreator::createLayer( - const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream, - const GenericLayerParams& layerParsePrms) { - checkParameters(inputs, layerParsePrms, 5); - pugi::xml_node dn = node.child("data"); - if (dn.empty()) - THROW_IE_EXCEPTION << "Cannot read parameter for " << getType() << " layer with name: " << layerParsePrms.name; - - float eps = GetFloatAttr(dn, "eps"); - return std::make_shared(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], eps); -} - // CTCGreedyDecoder layer template <> std::shared_ptr V10Parser::LayerCreator::createLayer( diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp index 8ca8879..7845d83 100644 --- a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp +++ b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp @@ -11,6 +11,7 @@ #include #include +#include using namespace std; @@ -18,6 +19,7 @@ namespace ngraph { namespace pass { class TRANSFORMATIONS_API BatchNormDecomposition; +class TRANSFORMATIONS_API BatchNormV5Decomposition; } // namespace pass } // namespace ngraph @@ -27,3 +29,9 @@ public: NGRAPH_RTTI_DECLARATION; BatchNormDecomposition(); }; + +class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + BatchNormV5Decomposition(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 059faa7..7f2a835 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -93,6 +93,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); decomp->add_matcher(); decomp->add_matcher(); + decomp->add_matcher(); decomp->set_name("ngraph::pass::CommonDecompositions"); // CF is required after all decompositions diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp index 97e3182..e76e302 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp @@ -8,8 +8,11 @@ #include #include +#include #include +using namespace ngraph; + NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0); ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { @@ -43,39 +46,107 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { const auto& input_type = m_input->get_element_type(); // scale_add = variance + eps - auto scale_add = make_shared(m_var, opset1::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()})); + auto scale_add = make_shared(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()})); // scale = sqrt(variance + eps) - auto scale = make_shared(scale_add); + auto scale = make_shared(scale_add); // Divide `gamma` by `sqrt(variance + eps)` - auto gamma_div_scale = std::make_shared(m_gamma, scale); + auto gamma_div_scale = std::make_shared(m_gamma, scale); size_t dims_to_add = m_input->get_shape().size() - 2; Shape input_aligned_shape = m_gamma->get_shape(); for (size_t i = 0; i < dims_to_add; ++i) input_aligned_shape.push_back(1); - auto new_shape = opset1::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape); + auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape); - auto gamma_div_scale_aligned = make_shared(gamma_div_scale, new_shape, true); - auto beta_aligned = make_shared(m_beta, new_shape, true); - auto mean_aligned = make_shared(m_mean, new_shape, true); + auto gamma_div_scale_aligned = make_shared(gamma_div_scale, new_shape, true); + auto beta_aligned = make_shared(m_beta, new_shape, true); + auto mean_aligned = make_shared(m_mean, new_shape, true); // input_sub_mean = input - mean - auto input_sub_mean = register_new_node(m_input, mean_aligned); + auto input_sub_mean = register_new_node(m_input, mean_aligned); // Multiply `input - mean` and `gamma / sqrt(variance + eps)` - auto mul = std::make_shared(input_sub_mean, gamma_div_scale_aligned); + auto mul = std::make_shared(input_sub_mean, gamma_div_scale_aligned); // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta` - auto add = std::make_shared(mul, beta_aligned); + auto add = std::make_shared(mul, beta_aligned); add->set_friendly_name(m_bn->get_friendly_name()); copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, - beta_aligned, input_sub_mean, mul, add}); + beta_aligned, input_sub_mean, mul, add}); replace_node(m_bn, add); return true; }; + auto m = std::make_shared(bn, "BatchNormDecomposition"); + this->register_matcher(m, callback); +} + +NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5); + +ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() { + Shape shape{2, 2, 1, 1}; + auto input = make_shared(element::f32, shape); + auto mean_shape = Shape{2}; + auto mean = make_shared(element::f32, mean_shape); + auto var_shape = Shape{2}; + auto var = make_shared(element::f32, var_shape); + auto gamma_shape = Shape{2}; + auto gamma = make_shared(element::f32, gamma_shape); + auto beta_shape = Shape{2}; + auto beta = make_shared(element::f32, beta_shape); + auto bn = make_shared(input, gamma, beta, mean, var, 0.001); + + ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) { + auto pattern_map = m.get_pattern_map(); + + auto m_input = pattern_map[input]; + auto m_gamma = pattern_map[gamma]; + auto m_beta = pattern_map[beta]; + auto m_mean = pattern_map[mean]; + auto m_var = pattern_map[var]; + + // TODO: check that all input shapes are static + auto m_bn = dynamic_pointer_cast(m.get_match_root()); + if (!m_bn) { + return false; + } + + const auto& input_type = m_input->get_element_type(); + // scale_add = variance + eps + auto scale_add = make_shared(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()})); + // scale = sqrt(variance + eps) + auto scale = make_shared(scale_add); + // Divide `gamma` by `sqrt(variance + eps)` + auto gamma_div_scale = std::make_shared(m_gamma, scale); + + size_t dims_to_add = m_input->get_shape().size() - 2; + Shape input_aligned_shape = m_gamma->get_shape(); + for (size_t i = 0; i < dims_to_add; ++i) + input_aligned_shape.push_back(1); + auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape); + + auto gamma_div_scale_aligned = make_shared(gamma_div_scale, new_shape, true); + auto beta_aligned = make_shared(m_beta, new_shape, true); + auto mean_aligned = make_shared(m_mean, new_shape, true); + + // input_sub_mean = input - mean + auto input_sub_mean = register_new_node(m_input, mean_aligned); + // Multiply `input - mean` and `gamma / sqrt(variance + eps)` + auto mul = std::make_shared(input_sub_mean, gamma_div_scale_aligned); + // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta` + auto add = std::make_shared(mul, beta_aligned); + + add->set_friendly_name(m_bn->get_friendly_name()); + + copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, + beta_aligned, input_sub_mean, mul, add}); + + replace_node(m_bn, add); + + return true; + }; auto m = std::make_shared(bn, "BatchNormDecomposition"); this->register_matcher(m, callback); } diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp index 45028d1..94c8b36 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp @@ -87,8 +87,8 @@ TEST_F(NGraphReaderTests, ReadBatchNormInferenceNetwork) { - - + + 1 diff --git a/inference-engine/tests/ngraph_functions/src/batch_norm.cpp b/inference-engine/tests/ngraph_functions/src/batch_norm.cpp index 14f4035..d45972b 100644 --- a/inference-engine/tests/ngraph_functions/src/batch_norm.cpp +++ b/inference-engine/tests/ngraph_functions/src/batch_norm.cpp @@ -24,7 +24,7 @@ std::shared_ptr makeBatchNormInference(const ngraph::Output& std::uniform_real_distribution dis(0.0, 10.0); std::generate(values.begin(), values.end(), [&dis, &gen]() { return dis(gen); }); auto variance = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, !random); - return std::make_shared(data, gamma, beta, mean, variance, epsilon); + return std::make_shared(data, gamma, beta, mean, variance, epsilon); } } // namespace builder } // namespace ngraph diff --git a/ngraph/core/include/ngraph/op/batch_norm.hpp b/ngraph/core/include/ngraph/op/batch_norm.hpp index a81870d..78aba1e 100644 --- a/ngraph/core/include/ngraph/op/batch_norm.hpp +++ b/ngraph/core/include/ngraph/op/batch_norm.hpp @@ -31,8 +31,7 @@ namespace ngraph class NGRAPH_API BatchNormInference : public Op { public: - static constexpr NodeTypeInfo type_info{"BatchNormInference", 0}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; BatchNormInference() = default; /// \param input [., C, ...] /// \param gamma gamma scaling for normalized value. [C] @@ -66,6 +65,44 @@ namespace ngraph double m_epsilon; }; } // namespace v0 - using v0::BatchNormInference; + namespace v5 + { + class NGRAPH_API BatchNormInference : public Op + { + public: + NGRAPH_RTTI_DECLARATION; + BatchNormInference() = default; + /// \param input [., C, ...] + /// \param gamma gamma scaling for normalized value. [C] + /// \param beta bias added to the scaled normalized value [C] + /// \param mean value for mean normalization [C] + /// \param variance value for variance normalization [C] + /// \param epsilon Avoids divsion by 0 if input has 0 variance + BatchNormInference(const Output& input, + const Output& gamma, + const Output& beta, + const Output& mean, + const Output& variance, + double epsilon); + + bool visit_attributes(AttributeVisitor& visitor) override; + + void validate_and_infer_types() override; + + double get_eps_value() const { return m_epsilon; } + void set_eps_value(double epsilon) { m_epsilon = epsilon; } + std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + + private: + static constexpr size_t INPUT_DATA = 0; + static constexpr size_t INPUT_GAMMA = 1; + static constexpr size_t INPUT_BETA = 2; + static constexpr size_t INPUT_MEAN = 3; + static constexpr size_t INPUT_VARIANCE = 4; + + double m_epsilon; + }; + } // namespace v0 } } diff --git a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp index 2fbc5f6..d22b0df 100644 --- a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp +++ b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp @@ -25,7 +25,7 @@ NGRAPH_OP(Add, ngraph::op::v1) NGRAPH_OP(Asin, ngraph::op::v0) NGRAPH_OP(Atan, ngraph::op::v0) NGRAPH_OP(AvgPool, ngraph::op::v1) -NGRAPH_OP(BatchNormInference, ngraph::op::v0) +NGRAPH_OP(BatchNormInference, ngraph::op::v5) NGRAPH_OP(BinaryConvolution, ngraph::op::v1) NGRAPH_OP(Broadcast, ngraph::op::v3) NGRAPH_OP(Bucketize, ngraph::op::v3) diff --git a/ngraph/core/src/op/batch_norm.cpp b/ngraph/core/src/op/batch_norm.cpp index 04470f9..a778c4c 100644 --- a/ngraph/core/src/op/batch_norm.cpp +++ b/ngraph/core/src/op/batch_norm.cpp @@ -23,27 +23,27 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::BatchNormInference::type_info; - -op::BatchNormInference::BatchNormInference(const Output& input, - const Output& gamma, - const Output& beta, - const Output& mean, - const Output& variance, - double epsilon) +NGRAPH_RTTI_DEFINITION(op::v0::BatchNormInference, "batchNormInference", 0); + +op::v0::BatchNormInference::BatchNormInference(const Output& input, + const Output& gamma, + const Output& beta, + const Output& mean, + const Output& variance, + double epsilon) : Op({gamma, beta, input, mean, variance}) , m_epsilon(epsilon) { constructor_validate_and_infer_types(); } -bool op::BatchNormInference::visit_attributes(AttributeVisitor& visitor) +bool op::v0::BatchNormInference::visit_attributes(AttributeVisitor& visitor) { visitor.on_attribute("epsilon", m_epsilon); return true; } -void op::BatchNormInference::validate_and_infer_types() +void op::v0::BatchNormInference::validate_and_infer_types() { element::Type result_et; PartialShape result_batch_shape; @@ -67,9 +67,60 @@ void op::BatchNormInference::validate_and_infer_types() } std::shared_ptr - op::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const + op::v0::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const { check_new_args_count(this, new_args); return std::make_shared( new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon); } + +NGRAPH_RTTI_DEFINITION(op::v5::BatchNormInference, "BatchNormInference", 5); + +op::v5::BatchNormInference::BatchNormInference(const Output& input, + const Output& gamma, + const Output& beta, + const Output& mean, + const Output& variance, + double epsilon) + : Op({input, gamma, beta, mean, variance}) + , m_epsilon(epsilon) +{ + constructor_validate_and_infer_types(); +} + +bool op::v5::BatchNormInference::visit_attributes(AttributeVisitor& visitor) +{ + visitor.on_attribute("epsilon", m_epsilon); + return true; +} + +void op::v5::BatchNormInference::validate_and_infer_types() +{ + element::Type result_et; + PartialShape result_batch_shape; + PartialShape result_channel_shape; // unused here + + set_output_size(1); + std::tie(result_et, result_batch_shape, result_channel_shape) = + infer_batch_norm_forward(this, + get_input_element_type(INPUT_DATA), + get_input_element_type(INPUT_GAMMA), + get_input_element_type(INPUT_BETA), + get_input_element_type(INPUT_MEAN), + get_input_element_type(INPUT_VARIANCE), + get_input_partial_shape(INPUT_DATA), + get_input_partial_shape(INPUT_GAMMA), + get_input_partial_shape(INPUT_BETA), + get_input_partial_shape(INPUT_MEAN), + get_input_partial_shape(INPUT_VARIANCE)); + + set_output_type(0, result_et, result_batch_shape); +} + +std::shared_ptr + op::v5::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const +{ + check_new_args_count(this, new_args); + return std::make_shared( + new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), m_epsilon); +} diff --git a/ngraph/python/src/ngraph/opset5/__init__.py b/ngraph/python/src/ngraph/opset5/__init__.py index 8c115b0..0141b1e 100644 --- a/ngraph/python/src/ngraph/opset5/__init__.py +++ b/ngraph/python/src/ngraph/opset5/__init__.py @@ -26,7 +26,7 @@ from ngraph.opset3.ops import assign from ngraph.opset1.ops import atan from ngraph.opset4.ops import atanh from ngraph.opset1.ops import avg_pool -from ngraph.opset1.ops import batch_norm_inference +from ngraph.opset5.ops import batch_norm_inference from ngraph.opset2.ops import batch_to_space from ngraph.opset1.ops import binary_convolution from ngraph.opset3.ops import broadcast diff --git a/ngraph/python/src/ngraph/opset5/ops.py b/ngraph/python/src/ngraph/opset5/ops.py index 0c84162..ab6f9c5 100644 --- a/ngraph/python/src/ngraph/opset5/ops.py +++ b/ngraph/python/src/ngraph/opset5/ops.py @@ -59,6 +59,32 @@ _get_node_factory_opset5 = partial(_get_node_factory, "opset5") @nameable_op +def batch_norm_inference( + data: NodeInput, + gamma: NodeInput, + beta: NodeInput, + mean: NodeInput, + variance: NodeInput, + epsilon: float, + name: Optional[str] = None, +) -> Node: + """Perform layer normalizes a input tensor by mean and variance with appling scale and offset. + + :param data: The input tensor with data for normalization. + :param gamma: The scalar scaling for normalized value. + :param beta: The bias added to the scaled normalized value. + :param mean: The value for mean normalization. + :param variance: The value for variance normalization. + :param epsilon: The number to be added to the variance to avoid division + by zero when normalizing a value. + :param name: The optional name of the output node. + :return: The new node which performs BatchNormInference. + """ + inputs = as_nodes(data, gamma, beta, mean, variance) + return _get_node_factory_opset5().create("BatchNormInference", inputs, {"epsilon": epsilon}) + + +@nameable_op def gather_nd( data: NodeInput, indices: NodeInput, diff --git a/ngraph/test/backend/batch_norm.in.cpp b/ngraph/test/backend/batch_norm.in.cpp index 49f277c..d4b501c 100644 --- a/ngraph/test/backend/batch_norm.in.cpp +++ b/ngraph/test/backend/batch_norm.in.cpp @@ -46,7 +46,8 @@ public: auto Beta = make_shared(etype, channel_shape); auto Mean = make_shared(etype, channel_shape); auto Variance = make_shared(etype, channel_shape); - auto BN = make_shared(Input, Gamma, Beta, Mean, Variance, epsilon); + auto BN = + make_shared(Input, Gamma, Beta, Mean, Variance, epsilon); m_function = make_shared(BN, ParameterVector{Input, Gamma, Beta, Mean, Variance}); m_input = backend->create_tensor(etype, input_shape); @@ -285,7 +286,52 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication) double eps = 0.001; auto shape_r = Shape{2, 2, 2, 1}; - auto bn = make_shared(input, mvgb, mvgb, mvgb, mvgb, eps); + auto bn = make_shared(input, mvgb, mvgb, mvgb, mvgb, eps); + + auto f = make_shared(bn, ParameterVector{input, mvgb, mvgb, mvgb, mvgb}); + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + // Create some tensors for input/output + auto _input = backend->create_tensor(element::f32, input_shape); + copy_data(_input, + vector{0.54881352f, + 0.71518934f, + 0.60276335f, + 0.54488319f, + 0.42365479f, + 0.64589411f, + 0.4375872f, + 0.89177299f}); + + auto _mvgb = backend->create_tensor(element::f32, mvgb_shape); + copy_data(_mvgb, vector{1.0f, 1.0f}); + auto bn_output = backend->create_tensor(element::f32, shape_r); + + vector expected_result{0.54903894f, + 0.71533161f, + 0.60296183f, + 0.54511058f, + 0.42394274f, + 0.64607101f, + 0.43786817f, + 0.89182704f}; + auto handle = backend->compile(f); + handle->call_with_validate({bn_output}, {_input, _mvgb, _mvgb, _mvgb, _mvgb}); + + ASSERT_TRUE( + ngraph::test::all_close(expected_result, read_vector(bn_output), 1e-3f, 1e-4f)); +} + +NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication_v5) +{ + auto input_shape = Shape{2, 2, 2, 1}; + auto input = make_shared(element::f32, input_shape); + + auto mvgb_shape = Shape{2}; + auto mvgb = make_shared(element::f32, mvgb_shape); + + double eps = 0.001; + auto shape_r = Shape{2, 2, 2, 1}; + auto bn = make_shared(input, mvgb, mvgb, mvgb, mvgb, eps); auto f = make_shared(bn, ParameterVector{input, mvgb, mvgb, mvgb, mvgb}); auto backend = runtime::Backend::create("${BACKEND_NAME}"); @@ -334,7 +380,56 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1) auto var = make_shared(element::f32, var_shape); double eps = 0.001; auto shape_r = Shape{2, 2, 2, 1}; - auto bn = make_shared(input, gamma, beta, mean, var, eps); + auto bn = make_shared(input, gamma, beta, mean, var, eps); + + auto f = make_shared(bn, ParameterVector{input, gamma, beta, mean, var}); + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + // Create some tensors for input/output + auto _input = backend->create_tensor(element::f32, input_shape); + copy_data(_input, + vector{0.54881352f, + 0.71518934f, + 0.60276335f, + 0.54488319f, + 0.42365479f, + 0.64589411f, + 0.4375872f, + 0.89177299f}); + + auto _gamma = backend->create_tensor(element::f32, gamma_shape); + copy_data(_gamma, vector{1.0f, 1.0f}); + auto _beta = backend->create_tensor(element::f32, beta_shape); + copy_data(_beta, vector{0.0f, 0.0f}); + auto _mean = backend->create_tensor(element::f32, mean_shape); + copy_data(_mean, vector{0.583388f, 0.619252f}); + auto _var = backend->create_tensor(element::f32, var_shape); + copy_data(_var, vector{0.0119972f, 0.0282681f}); + auto bn_output = backend->create_tensor(element::f32, shape_r); + + vector expected_result{ + -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f}; + auto handle = backend->compile(f); + handle->call_with_validate({bn_output}, {_input, _gamma, _beta, _mean, _var}); + + ASSERT_TRUE( + ngraph::test::all_close(expected_result, read_vector(bn_output), 1e-3f, 1e-4f)); +} + +NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1_v5) +{ + auto input_shape = Shape{2, 2, 2, 1}; + auto input = make_shared(element::f32, input_shape); + auto gamma_shape = Shape{2}; + auto gamma = make_shared(element::f32, gamma_shape); + auto beta_shape = Shape{2}; + auto beta = make_shared(element::f32, beta_shape); + auto mean_shape = Shape{2}; + auto mean = make_shared(element::f32, mean_shape); + auto var_shape = Shape{2}; + auto var = make_shared(element::f32, var_shape); + double eps = 0.001; + auto shape_r = Shape{2, 2, 2, 1}; + auto bn = make_shared(input, gamma, beta, mean, var, eps); auto f = make_shared(bn, ParameterVector{input, gamma, beta, mean, var}); auto backend = runtime::Backend::create("${BACKEND_NAME}"); diff --git a/ngraph/test/op_is.cpp b/ngraph/test/op_is.cpp index a4504f2..fe64a5f 100644 --- a/ngraph/test/op_is.cpp +++ b/ngraph/test/op_is.cpp @@ -85,7 +85,7 @@ namespace void op_is_BatchNormInference() { - op::BatchNormInference node; + op::v0::BatchNormInference node; EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node)); EXPECT_FALSE(op::is_binary_elementwise_comparison(&node)); diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index fd516f5..bbad53c 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -997,6 +997,7 @@ batch_norm_training_0eps_f64 # Function inputs number differ from number of given inputs batch_norm_inference_parameters_duplication +batch_norm_inference_parameters_duplication_v5 backwards_abs backwards_acos diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index be993f6..b3b3550 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -250,8 +250,8 @@ protected: } case OP_TYPEID::BatchNormInference: { - const ngraph::op::BatchNormInference* bn = - static_cast(&node); + const ngraph::op::v0::BatchNormInference* bn = + static_cast(&node); reference::batch_norm_inference(bn->get_eps_value(), args[0]->get_data_ptr(), args[1]->get_data_ptr(), @@ -262,6 +262,20 @@ protected: node.get_input_shape(2)); break; } + case OP_TYPEID::BatchNormInference_v5: + { + const ngraph::op::v5::BatchNormInference* bn = + static_cast(&node); + reference::batch_norm_inference(bn->get_eps_value(), + args[1]->get_data_ptr(), + args[2]->get_data_ptr(), + args[0]->get_data_ptr(), + args[3]->get_data_ptr(), + args[4]->get_data_ptr(), + out[0]->get_data_ptr(), + node.get_input_shape(0)); + break; + } case OP_TYPEID::BroadcastLike: break; case OP_TYPEID::Ceiling: { diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index f9e1ee4..1570b69 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -57,6 +57,7 @@ NGRAPH_OP(GatherND, op::v5) NGRAPH_OP(LSTMSequence, op::v5) NGRAPH_OP(GRUSequence, op::v5) NGRAPH_OP(RNNSequence, op::v5) +NGRAPH_OP(BatchNormInference, op::v5) NGRAPH_OP(Round, op::v5) NGRAPH_OP(LogSoftmax, op::v5) #undef ID_SUFFIX diff --git a/ngraph/test/runtime/opset0_tbl.hpp b/ngraph/test/runtime/opset0_tbl.hpp index 2d91822..abf6a58 100644 --- a/ngraph/test/runtime/opset0_tbl.hpp +++ b/ngraph/test/runtime/opset0_tbl.hpp @@ -56,7 +56,7 @@ NGRAPH_OP(Add, ngraph::op) NGRAPH_OP(Asin, ngraph::op) NGRAPH_OP(Atan, ngraph::op) NGRAPH_OP(AvgPool, ngraph::op::v0) -NGRAPH_OP(BatchNormInference, ngraph::op) +NGRAPH_OP(BatchNormInference, ngraph::op::v0) NGRAPH_OP(Broadcast, ngraph::op) NGRAPH_OP(BroadcastLike, ngraph::op) NGRAPH_OP(Ceiling, ngraph::op) diff --git a/ngraph/test/type_prop/batch_norm.cpp b/ngraph/test/type_prop/batch_norm.cpp index c2cd984..0ab600a 100644 --- a/ngraph/test/type_prop/batch_norm.cpp +++ b/ngraph/test/type_prop/batch_norm.cpp @@ -41,7 +41,8 @@ TEST(type_prop, batch_norm_inference_partial_all_rank_dynamic) auto mean = make_shared(mean_et, mean_shape); auto variance = make_shared(variance_et, variance_shape); - auto bn = make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); ASSERT_EQ(bn->get_output_size(), 1); ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); @@ -69,7 +70,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_ok) auto mean = make_shared(mean_et, mean_shape); auto variance = make_shared(variance_et, variance_shape); - auto bn = make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); ASSERT_EQ(bn->get_output_size(), 1); ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); @@ -100,8 +102,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_zero_chan try { - auto bn = - make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); FAIL() << "Zero channel count not detected"; } catch (const NodeValidationFailure& error) @@ -134,7 +136,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static auto mean = make_shared(mean_et, mean_shape); auto variance = make_shared(variance_et, variance_shape); - auto bn = make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); ASSERT_EQ(bn->get_output_size(), 1); ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); @@ -163,8 +166,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static try { - auto bn = - make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); FAIL() << "Wrong gamma/beta/mean/variance shape not detected"; } catch (const NodeValidationFailure& error) @@ -202,8 +205,8 @@ TEST(type_prop, try { - auto bn = - make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected"; } catch (const NodeValidationFailure& error) @@ -240,8 +243,8 @@ TEST(type_prop, try { - auto bn = - make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected"; } catch (const NodeValidationFailure& error) @@ -275,7 +278,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_some_stat auto mean = make_shared(mean_et, mean_shape); auto variance = make_shared(variance_et, variance_shape); - auto bn = make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); ASSERT_EQ(bn->get_output_size(), 1); ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); @@ -306,8 +310,315 @@ TEST(type_prop, try { - auto bn = - make_shared(data_batch, gamma, beta, mean, variance, epsilon); + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); + FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Input channel dimension (4) does not match " + "shape for gamma/beta/mean/variance ({3})")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, batch_norm_inference_partial_all_rank_dynamic_v5) +{ + PartialShape data_batch_shape{PartialShape::dynamic()}; + PartialShape gamma_shape{PartialShape::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{PartialShape::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); + + ASSERT_EQ(bn->get_output_size(), 1); + ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); + ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic()); +} + +TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_ok_v5) +{ + PartialShape data_batch_shape{ + 64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}; + PartialShape gamma_shape{PartialShape::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{PartialShape::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); + + ASSERT_EQ(bn->get_output_size(), 1); + ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); + ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme( + PartialShape{64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()})); +} + +TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_zero_channels_v5) +{ + PartialShape data_batch_shape{ + Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()}; + PartialShape gamma_shape{PartialShape::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{PartialShape::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + try + { + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); + FAIL() << "Zero channel count not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_ok_v5) +{ + PartialShape data_batch_shape{PartialShape::dynamic()}; + PartialShape gamma_shape{Dimension::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{Dimension::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); + + ASSERT_EQ(bn->get_output_size(), 1); + ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); + ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic()); +} + +TEST(type_prop, + batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_wrong_rank_v5) +{ + PartialShape data_batch_shape{PartialShape::dynamic()}; + PartialShape gamma_shape{Dimension::dynamic(), Dimension::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{Dimension::dynamic(), Dimension::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + try + { + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); + FAIL() << "Wrong gamma/beta/mean/variance shape not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Shape for gamma/beta/mean/variance ({?,?}) does not have rank 1")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, + batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_inconsistent_rank_v5) +{ + PartialShape data_batch_shape{PartialShape::dynamic()}; + PartialShape gamma_shape{3, Dimension::dynamic()}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{Dimension::dynamic()}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + try + { + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); + FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Shapes for gamma/beta/mean/variance do not match")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, + batch_norm_inference_partial_input_rank_dynamic_some_static_inconsistent_channel_count_v5) +{ + PartialShape data_batch_shape{PartialShape::dynamic()}; + PartialShape gamma_shape{3}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{4}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + try + { + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); + FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Shapes for gamma/beta/mean/variance do not match")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_some_static_ok_v5) +{ + PartialShape data_batch_shape{64, Dimension::dynamic(), Dimension::dynamic(), 224}; + PartialShape gamma_shape{3}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{3}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + auto bn = + make_shared(data_batch, gamma, beta, mean, variance, epsilon); + + ASSERT_EQ(bn->get_output_size(), 1); + ASSERT_EQ(bn->get_output_element_type(0), data_batch_et); + ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme( + PartialShape{64, 3, Dimension::dynamic(), 224})); +} + +TEST( + type_prop, + batch_norm_inference_partial_input_rank_static_dynamic_some_static_inconsistent_channel_count_v5) +{ + PartialShape data_batch_shape{64, 4, Dimension::dynamic(), 224}; + PartialShape gamma_shape{3}; + PartialShape beta_shape{PartialShape::dynamic()}; + PartialShape mean_shape{3}; + PartialShape variance_shape{PartialShape::dynamic()}; + double epsilon = 0.001; + element::Type data_batch_et = element::f32; + element::Type gamma_et = element::f32; + element::Type beta_et = element::f32; + element::Type mean_et = element::f32; + element::Type variance_et = element::f32; + + auto data_batch = make_shared(data_batch_et, data_batch_shape); + auto gamma = make_shared(gamma_et, gamma_shape); + auto beta = make_shared(beta_et, beta_shape); + auto mean = make_shared(mean_et, mean_shape); + auto variance = make_shared(variance_et, variance_shape); + + try + { + auto bn = make_shared( + data_batch, gamma, beta, mean, variance, epsilon); FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected"; } catch (const NodeValidationFailure& error) -- 2.7.4