51e24ad445e63e54427b5541c751495a5692446f
[platform/upstream/dldt.git] / ngraph / core / src / op / product.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/product.hpp"
18 #include "itt.hpp"
19 #include "ngraph/graph_util.hpp"
20 #include "ngraph/runtime/host_tensor.hpp"
21 #include "ngraph/runtime/reference/product.hpp"
22 #include "ngraph/shape_util.hpp"
23
24 NGRAPH_SUPPRESS_DEPRECATED_START
25
26 using namespace std;
27 using namespace ngraph;
28
29 constexpr NodeTypeInfo op::v0::Product::type_info;
30
31 op::v0::Product::Product(const Output<Node>& arg, const AxisSet& reduction_axes)
32     : ArithmeticReduction(arg, reduction_axes)
33 {
34     constructor_validate_and_infer_types();
35 }
36
37 op::v0::Product::Product(const Output<Node>& arg, const Output<Node>& reduction_axes)
38     : ArithmeticReduction(arg, reduction_axes)
39 {
40     constructor_validate_and_infer_types();
41 }
42
43 shared_ptr<Node> op::v0::Product::clone_with_new_inputs(const OutputVector& new_args) const
44 {
45     check_new_args_count(this, new_args);
46     return make_shared<op::v0::Product>(new_args.at(0), get_reduction_axes());
47 }
48
49 shared_ptr<Node> op::v0::Product::get_default_value() const
50 {
51     return ngraph::make_constant_from_string("1", get_element_type(), get_shape());
52 }
53
54 namespace
55 {
56     template <element::Type_t ET>
57     bool evaluate(const HostTensorPtr& arg,
58                   const HostTensorPtr& out,
59                   const AxisSet& axes,
60                   bool keep_dims)
61     {
62         out->set_shape(reduce(arg->get_shape(), axes, keep_dims));
63         runtime::reference::product(
64             arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), arg->get_shape(), axes, keep_dims);
65         return true;
66     }
67
68     bool evaluate_product(const HostTensorPtr& arg,
69                           const HostTensorPtr& out,
70                           const AxisSet& axes,
71                           bool keep_dims)
72     {
73         bool rc = true;
74         switch (arg->get_element_type())
75         {
76             TYPE_CASE(i32)(arg, out, axes, keep_dims);
77             break;
78             TYPE_CASE(i64)(arg, out, axes, keep_dims);
79             break;
80             TYPE_CASE(u32)(arg, out, axes, keep_dims);
81             break;
82             TYPE_CASE(u64)(arg, out, axes, keep_dims);
83             break;
84             TYPE_CASE(f16)(arg, out, axes, keep_dims);
85             break;
86             TYPE_CASE(f32)(arg, out, axes, keep_dims);
87             break;
88         default: rc = false; break;
89         }
90         return rc;
91     }
92 }
93
94 bool op::v0::Product::evaluate(const HostTensorVector& outputs,
95                                const HostTensorVector& inputs) const
96 {
97     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Product::evaluate");
98     return evaluate_product(inputs[0], outputs[0], get_reduction_axes(), false);
99 }