Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Helper / MatmulBCast.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__
19 #define __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__
20
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
24
25 #include "BCast.h"
26 #include "cker/Shape.h"
27
28 namespace nnfw
29 {
30 namespace cker
31 {
32
33 // Simple wrapper over BCast specialized for MatMul.
34 // Provides utilities for broadcasting across batch dimensions for binary
35 // MatMul-like operations.
36
37 // Fix: Use Shape directly instead of Vec
38 class MatMulBCast
39 {
40 public:
41   MatMulBCast(Shape &shape_x, Shape &shape_y)
42   {
43     if (shape_x.DimensionsCount() < 2 || shape_y.DimensionsCount() < 2)
44       return;
45
46     std::vector<int32_t> x;
47     std::vector<int32_t> y;
48
49     x.resize(shape_x.DimensionsCount() - 2);
50     y.resize(shape_y.DimensionsCount() - 2);
51
52     for (size_t i = 0; i < x.size(); i++)
53     {
54       x[i] = shape_x.Dims(i);
55     }
56     for (size_t i = 0; i < y.size(); i++)
57     {
58       y[i] = shape_y.Dims(i);
59     }
60
61     _batch_bcast = std::make_unique<BCast>(std::move(x), std::move(y));
62     if (!_batch_bcast->IsValid())
63       return;
64
65     auto x_reshaped = _batch_bcast->x_reshape();
66     auto y_reshaped = _batch_bcast->y_reshape();
67     auto output_shape = _batch_bcast->output_shape();
68
69     _x_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
70                                     std::multiplies<int32_t>());
71     _y_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
72                                     std::multiplies<int32_t>());
73     _output_shape.ReplaceWith(output_shape.size(), output_shape.data());
74     _output_batch_size = _output_shape.FlatSize();
75   }
76
77   bool IsValid() const { return (_batch_bcast != nullptr) && _batch_bcast->IsValid(); }
78   int32_t x_batch_size() const { return _x_batch_size; }
79   int32_t y_batch_size() const { return _y_batch_size; }
80   int32_t output_batch_size() const { return _output_batch_size; }
81   const Shape &output_batch_shape() const { return _output_shape; }
82
83 private:
84   std::unique_ptr<BCast> _batch_bcast;
85
86   int32_t _x_batch_size;
87   int32_t _y_batch_size;
88   Shape _output_shape;
89   int32_t _output_batch_size;
90 };
91
92 } // namespace cker
93 } // namespace nnfw
94
95 #endif // __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__