a6e721a09cb4534ff4bf11399f4dff76ef2c4d74
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Mul.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Mul.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
23
24 #include <stdexcept>
25
26 namespace luci_interpreter
27 {
28 namespace kernels
29 {
30
31 Mul::Mul(const Tensor *input1, const Tensor *input2, Tensor *output, const MulParams &params)
32     : KernelWithParams<MulParams>({input1, input2}, {output}, params)
33 {
34 }
35
36 void Mul::configure()
37 {
38   assert(input1()->element_type() == input2()->element_type());
39   output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
40 }
41
42 void Mul::execute() const
43 {
44   switch (input1()->element_type())
45   {
46     case DataType::FLOAT32:
47       evalFloat();
48       break;
49     default:
50       throw std::runtime_error("Unsupported type.");
51   }
52 }
53
54 void Mul::evalFloat() const
55 {
56   float activation_min{};
57   float activation_max{};
58   calculateActivationRange(_params.activation, &activation_min, &activation_max);
59
60   tflite::ArithmeticParams params{};
61   params.float_activation_min = activation_min;
62   params.float_activation_max = activation_max;
63
64   const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
65       getTensorShape(input1()), getTensorShape(input2()), &params);
66
67   if (need_broadcast)
68   {
69     tflite::reference_ops::BroadcastMul4DSlow(
70         params, getTensorShape(input1()), getTensorData<float>(input1()), getTensorShape(input2()),
71         getTensorData<float>(input2()), getTensorShape(output()), getTensorData<float>(output()));
72   }
73   else
74   {
75     tflite::reference_ops::Mul(params, getTensorShape(input1()), getTensorData<float>(input1()),
76                                getTensorShape(input2()), getTensorData<float>(input2()),
77                                getTensorShape(output()), getTensorData<float>(output()));
78   }
79 }
80
81 } // namespace kernels
82 } // namespace luci_interpreter