77b7315a91a79faeecd2ad5f7f231e70f014b0ec
[platform/core/ml/nnfw.git] / compiler / locomotiv / src / Node / MatMul.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "NodeExecution.h"
19
20 #include "NodeDataImpl.h"
21 #include "NodeDomain.h"
22 #include "Validation.h"
23
24 #include <nncc/core/ADT/tensor/Shape.h>
25 #include <nncc/core/ADT/tensor/Buffer.h>
26 #include <nncc/core/ADT/tensor/Index.h>
27 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
28 #include <nncc/core/ADT/tensor/LexicalLayout.h>
29
30 #include <cassert>
31 #include <stdexcept>
32
33 namespace
34 {
35 using nncc::core::ADT::tensor::Buffer;
36 using nncc::core::ADT::tensor::Shape;
37 using nncc::core::ADT::tensor::Index;
38 using nncc::core::ADT::tensor::LexicalLayout;
39 using nncc::core::ADT::tensor::make_buffer;
40
41 /**
42  * @brief Calculate Matrix Multiplication
43  */
44 template <typename T> Buffer<T> calc_mat_mul(const Buffer<T> *lhs_buf, const Buffer<T> *rhs_buf)
45 {
46   const auto lhs_shape = lhs_buf->shape();
47   const auto rhs_shape = rhs_buf->shape();
48
49   assert(lhs_shape.rank() == 2 && "lhs rank must be 2");
50   assert(rhs_shape.rank() == 2 && "rhs rank must be 2");
51   // lhs width should be the same as rhs height
52   assert(lhs_shape.dim(1) == rhs_shape.dim(0) && "height/width mismatch");
53
54   const uint32_t lhs_height = lhs_shape.dim(0);
55   const uint32_t lhs_width = lhs_shape.dim(1);
56
57   const uint32_t rhs_width = rhs_shape.dim(1);
58
59   const uint32_t output_height = lhs_height;
60   const uint32_t output_width = rhs_width;
61
62   Shape output_shape{output_height, output_width};
63   auto output_buf = make_buffer<T, LexicalLayout>(output_shape);
64
65   for (uint32_t out_y = 0; out_y < output_height; ++out_y)
66   {
67     for (uint32_t out_x = 0; out_x < output_width; ++out_x)
68     {
69       T total = static_cast<T>(0); // accumulator
70       // Accumulate through axis
71       for (uint32_t axis = 0; axis < lhs_width; ++axis)
72       {
73         total += lhs_buf->at(Index({out_y, axis})) * rhs_buf->at(Index({axis, out_x}));
74       }
75       // Set output value
76       output_buf.at(Index({out_y, out_x})) = total;
77     }
78   }
79
80   return output_buf;
81 }
82
83 } // namespace
84
85 namespace locomotiv
86 {
87
88 void NodeExecution::execute(loco::MatMul *mat_mul)
89 {
90   auto lhs_data = annot_data(mat_mul->lhs());
91   auto rhs_data = annot_data(mat_mul->rhs());
92
93   validate(lhs_data, "Can't find left matrix data of MatMul");
94   validate(lhs_data->shape()->rank() == 2, "lhs rank must be 2");
95
96   validate(rhs_data, "Can't find right matrix data of MatMul");
97   validate(rhs_data->shape()->rank() == 2, "rhs rank must be 2");
98
99   validate(annot_domain(mat_mul->lhs()) == loco::Domain::Matrix,
100            "Left matrix of MatMul is not a Matrix");
101   validate(annot_domain(mat_mul->rhs()) == loco::Domain::Matrix,
102            "Right matrix of MatMul is not a Matrix");
103
104   std::unique_ptr<NodeData> mat_mul_result = nullptr;
105
106   if (lhs_data->dtype() == loco::DataType::FLOAT32 && rhs_data->dtype() == loco::DataType::FLOAT32)
107   {
108     const auto lhs_buf = lhs_data->as_f32_bufptr();
109     const auto rhs_buf = rhs_data->as_f32_bufptr();
110
111     auto mat_mul_buf = calc_mat_mul<float>(lhs_buf, rhs_buf);
112
113     mat_mul_result = make_data(mat_mul_buf);
114   }
115   else if (lhs_data->dtype() == loco::DataType::S32 && rhs_data->dtype() == loco::DataType::S32)
116   {
117     const auto lhs_buf = lhs_data->as_s32_bufptr();
118     const auto rhs_buf = rhs_data->as_s32_bufptr();
119
120     auto mat_mul_buf = calc_mat_mul<int32_t>(lhs_buf, rhs_buf);
121
122     mat_mul_result = make_data(mat_mul_buf);
123   }
124   else
125     throw std::runtime_error("NYI for these DataTypes");
126
127   assert(mat_mul_result != nullptr);
128
129   annot_data(mat_mul, std::move(mat_mul_result));
130   annot_domain(mat_mul, loco::Domain::Matrix);
131 }
132
133 } // namespace locomotiv