1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "transformations/op_conversions/log_softmax_decomposition.hpp"
9 #include <ngraph/opsets/opset5.hpp>
10 #include <ngraph/rt_info.hpp>
11 #include <ngraph/pattern/op/wrap_type.hpp>
// Out-of-class RTTI (type-info) definition for the LogSoftmaxDecomposition pass.
// Arguments: the pass class, its type name string "LogSoftmaxDecomposition", and 0
// (presumably the type-info version number; the macro's expansion lives in nGraph
// headers not shown here). It presumably pairs with an RTTI declaration macro in the
// class header; confirm there.
NGRAPH_RTTI_DEFINITION(ngraph::pass::LogSoftmaxDecomposition, "LogSoftmaxDecomposition", 0);
15 ngraph::pass::LogSoftmaxDecomposition::LogSoftmaxDecomposition() {
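// Decomposes LogSoftmax(x, axis) into the numerically stable sub-graph
//   (x - max) - log(reduce_sum(exp(x - max), axis)),  where max = reduce_max(x, axis) with keep_dims.
// Subtracting max makes every exponent <= 0, so exp() cannot overflow; the result equals x - log(reduce_sum(exp(x), axis)).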
17 auto log_softmax = ngraph::pattern::wrap_type<opset5::LogSoftmax>();
19 ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
20 auto& pattern_to_output = m.get_pattern_value_map();
21 auto log_softmax_node = std::dynamic_pointer_cast<ngraph::opset5::LogSoftmax>(pattern_to_output.at(log_softmax).get_node_shared_ptr());
23 if (log_softmax_node == nullptr || m_transformation_callback(log_softmax_node)) {
27 auto axis1 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
28 auto axis2 = ngraph::opset5::Constant::create(element::Type_t::i64, ngraph::Shape{1}, { log_softmax_node->get_axis() });
29 auto max = std::make_shared<ngraph::opset5::ReduceMax>(log_softmax_node->input_value(0), axis1, true);
30 auto sub = std::make_shared<ngraph::opset5::Subtract>(log_softmax_node->input_value(0), max);
31 auto exp = std::make_shared<ngraph::opset5::Exp>(sub);
32 auto sum = std::make_shared<ngraph::opset5::ReduceSum>(exp, axis2, true);
33 auto log = std::make_shared<ngraph::opset5::Log>(sum);
34 auto sub_end = std::make_shared<ngraph::opset5::Subtract>(sub, log);
36 sub_end->set_friendly_name(m.get_match_root()->get_friendly_name());
37 ngraph::copy_runtime_info(log_softmax_node, { axis1, axis2, max, sub, exp, sum, log, sub_end });
38 ngraph::replace_node(m.get_match_root(), sub_end);
42 auto m = std::make_shared<ngraph::pattern::Matcher>(log_softmax, "LogSoftmaxDecomposition");
43 register_matcher(m, callback);