Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / pass / constant_folding_logical_reduction.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/any.hpp"
20 #include "ngraph/op/reduce_logical_and.hpp"
21 #include "ngraph/op/reduce_logical_or.hpp"
22 #include "ngraph/runtime/reference/any.hpp"
23 #include "ngraph/runtime/reference/logical_reduction.hpp"
24
25 NGRAPH_SUPPRESS_DEPRECATED_START
26
27 using namespace std;
28 using namespace ngraph;
29
30 static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
31                                                                 shared_ptr<Node> reduction_node)
32 {
33     runtime::AlignedBuffer buffer(shape_size(reduction_node->get_shape()) * sizeof(char));
34     char* data_ptr = buffer.get_ptr<char>();
35
36     if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node))
37     {
38         runtime::reference::any(constant->get_data_ptr<char>(),
39                                 data_ptr,
40                                 reduction_node->get_input_shape(0),
41                                 any->get_reduction_axes(),
42                                 false);
43     }
44     else if (auto reduce_and = as_type_ptr<::ngraph::op::v1::ReduceLogicalAnd>(reduction_node))
45     {
46         const auto reduction_axes = reduce_and->get_reduction_axes();
47         const auto input_shape = reduce_and->get_input_shape(0);
48         const char* arg = constant->get_data_ptr<char>();
49
50         runtime::reference::reduce_logical_and(
51             arg, data_ptr, input_shape, reduction_axes, reduce_and->get_keep_dims());
52     }
53     else if (auto reduce_or = as_type_ptr<::ngraph::op::v1::ReduceLogicalOr>(reduction_node))
54     {
55         const auto reduction_axes = reduce_or->get_reduction_axes();
56         const auto input_shape = reduce_or->get_input_shape(0);
57         const char* arg = constant->get_data_ptr<char>();
58
59         runtime::reference::reduce_logical_or(
60             arg, data_ptr, input_shape, reduction_axes, reduce_or->get_keep_dims());
61     }
62     else
63     {
64         NGRAPH_CHECK(false,
65                      "Internal nGraph error: Ops handled in "
66                      "fold_constant_logical_reduction must be consistent with those "
67                      "matched in construct_constant_logical_reduction");
68     }
69
70     return make_shared<op::Constant>(
71         reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
72 }
73
74 void pass::ConstantFolding::construct_constant_logical_reduction()
75 {
76     auto constant_data_label = make_shared<pattern::op::Label>(
77         element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
78     auto constant_axes_label =
79         make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
80     auto is_supported_reduction = [](std::shared_ptr<Node> n) {
81         return (pattern::has_class<::ngraph::op::Any>()(n) ||
82                 pattern::has_class<::ngraph::op::v1::ReduceLogicalAnd>()(n) ||
83                 pattern::has_class<::ngraph::op::v1::ReduceLogicalOr>()(n));
84     };
85     auto reduction =
86         std::make_shared<pattern::op::Any>(element::i32,
87                                            Shape{2},
88                                            is_supported_reduction,
89                                            NodeVector{constant_data_label, constant_axes_label});
90
91     auto constant_logical_reduction_callback = [this, constant_data_label](pattern::Matcher& m) {
92         NGRAPH_DEBUG << "In callback for constant_logical_reduction_callback against node = "
93                      << m.get_match_root()->get_name();
94
95         auto pattern_map = m.get_pattern_map();
96
97         auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
98         auto reduction_match = m.get_match_root();
99
100         if (cf_is_disabled(reduction_match))
101             return false;
102
103         NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
104
105         replace_node(reduction_match,
106                      fold_constant_logical_reduction(constant_match, reduction_match));
107         return true;
108     };
109
110     auto logical_reduction_matcher =
111         make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
112     NGRAPH_SUPPRESS_DEPRECATED_START
113     this->add_matcher(logical_reduction_matcher,
114                       constant_logical_reduction_callback,
115                       PassProperty::CHANGE_DYNAMIC_STATE);
116     NGRAPH_SUPPRESS_DEPRECATED_END
117 }