Fixed static analysis issues (transformations) (#3276)
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / op_conversions / log_softmax_decomposition.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/op_conversions/log_softmax_decomposition.hpp"
6
7 #include <memory>
8
9 #include <ngraph/opsets/opset5.hpp>
10 #include <ngraph/rt_info.hpp>
11 #include <ngraph/pattern/op/wrap_type.hpp>
12
// Emits the nGraph RTTI definition for the LogSoftmaxDecomposition pass, registered under the
// type name "LogSoftmaxDecomposition" (the trailing 0 is presumably the RTTI version — confirm
// against the NGRAPH_RTTI_DEFINITION macro).
13 NGRAPH_RTTI_DEFINITION(ngraph::pass::LogSoftmaxDecomposition, "LogSoftmaxDecomposition", 0);
14
15 ngraph::pass::LogSoftmaxDecomposition::LogSoftmaxDecomposition() {
16     // Decomposes LogSoftmax(x, axis) op into sub-graph x - log(reduce_sum(exp(x), axis))
17     auto log_softmax = ngraph::pattern::wrap_type<opset5::LogSoftmax>();
18
19     ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
20         auto& pattern_to_output = m.get_pattern_value_map();
21         auto log_softmax_node = std::dynamic_pointer_cast<ngraph::opset5::LogSoftmax>(pattern_to_output.at(log_softmax).get_node_shared_ptr());
22
23         if (log_softmax_node == nullptr || m_transformation_callback(log_softmax_node)) {
24             return false;
25         }
26
27         auto axis1 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
28         auto axis2 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
29         auto max = std::make_shared<ngraph::opset5::ReduceMax>(log_softmax_node->input_value(0), axis1, true);
30         auto sub = std::make_shared<ngraph::opset5::Subtract>(log_softmax_node->input_value(0), max);
31         auto exp = std::make_shared<ngraph::opset5::Exp>(sub);
32         auto sum = std::make_shared<ngraph::opset5::ReduceSum>(exp, axis2, true);
33         auto log = std::make_shared<ngraph::opset5::Log>(sum);
34         auto sub_end = std::make_shared<ngraph::opset5::Subtract>(sub, log);
35
36         sub_end->set_friendly_name(m.get_match_root()->get_friendly_name());
37         ngraph::copy_runtime_info(log_softmax_node, { axis1, axis2, max, sub, exp, sum, log, sub_end });
38         ngraph::replace_node(m.get_match_root(), sub_end);
39         return true;
40     };
41
42     auto m = std::make_shared<ngraph::pattern::Matcher>(log_softmax, "LogSoftmaxDecomposition");
43     register_matcher(m, callback);
44 }