/**
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "NodeExecution.h"
20 #include "NodeDataImpl.h"
21 #include "NodeDomain.h"
22 #include "Validation.h"
24 #include <nncc/core/ADT/tensor/Shape.h>
25 #include <nncc/core/ADT/tensor/Buffer.h>
26 #include <nncc/core/ADT/tensor/Index.h>
27 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
28 #include <nncc/core/ADT/tensor/LexicalLayout.h>
35 using nncc::core::ADT::tensor::Buffer;
36 using nncc::core::ADT::tensor::Shape;
37 using nncc::core::ADT::tensor::Index;
38 using nncc::core::ADT::tensor::LexicalLayout;
39 using nncc::core::ADT::tensor::make_buffer;
42 * @brief Calculate Matrix Multiplication
44 template <typename T> Buffer<T> calc_mat_mul(const Buffer<T> *lhs_buf, const Buffer<T> *rhs_buf)
46 const auto lhs_shape = lhs_buf->shape();
47 const auto rhs_shape = rhs_buf->shape();
49 assert(lhs_shape.rank() == 2 && "lhs rank must be 2");
50 assert(rhs_shape.rank() == 2 && "rhs rank must be 2");
51 // lhs width should be the same as rhs height
52 assert(lhs_shape.dim(1) == rhs_shape.dim(0) && "height/width mismatch");
54 const uint32_t lhs_height = lhs_shape.dim(0);
55 const uint32_t lhs_width = lhs_shape.dim(1);
57 const uint32_t rhs_width = rhs_shape.dim(1);
59 const uint32_t output_height = lhs_height;
60 const uint32_t output_width = rhs_width;
62 Shape output_shape{output_height, output_width};
63 auto output_buf = make_buffer<T, LexicalLayout>(output_shape);
65 for (uint32_t out_y = 0; out_y < output_height; ++out_y)
67 for (uint32_t out_x = 0; out_x < output_width; ++out_x)
69 T total = static_cast<T>(0); // accumulator
70 // Accumulate through axis
71 for (uint32_t axis = 0; axis < lhs_width; ++axis)
73 total += lhs_buf->at(Index({out_y, axis})) * rhs_buf->at(Index({axis, out_x}));
76 output_buf.at(Index({out_y, out_x})) = total;
88 void NodeExecution::execute(loco::MatMul *mat_mul)
90 auto lhs_data = annot_data(mat_mul->lhs());
91 auto rhs_data = annot_data(mat_mul->rhs());
93 validate(lhs_data, "Can't find left matrix data of MatMul");
94 validate(lhs_data->shape()->rank() == 2, "lhs rank must be 2");
96 validate(rhs_data, "Can't find right matrix data of MatMul");
97 validate(rhs_data->shape()->rank() == 2, "rhs rank must be 2");
99 validate(annot_domain(mat_mul->lhs()) == loco::Domain::Matrix,
100 "Left matrix of MatMul is not a Matrix");
101 validate(annot_domain(mat_mul->rhs()) == loco::Domain::Matrix,
102 "Right matrix of MatMul is not a Matrix");
104 std::unique_ptr<NodeData> mat_mul_result = nullptr;
106 if (lhs_data->dtype() == loco::DataType::FLOAT32 && rhs_data->dtype() == loco::DataType::FLOAT32)
108 const auto lhs_buf = lhs_data->as_f32_bufptr();
109 const auto rhs_buf = rhs_data->as_f32_bufptr();
111 auto mat_mul_buf = calc_mat_mul<float>(lhs_buf, rhs_buf);
113 mat_mul_result = make_data(mat_mul_buf);
115 else if (lhs_data->dtype() == loco::DataType::S32 && rhs_data->dtype() == loco::DataType::S32)
117 const auto lhs_buf = lhs_data->as_s32_bufptr();
118 const auto rhs_buf = rhs_data->as_s32_bufptr();
120 auto mat_mul_buf = calc_mat_mul<int32_t>(lhs_buf, rhs_buf);
122 mat_mul_result = make_data(mat_mul_buf);
125 throw std::runtime_error("NYI for these DataTypes");
127 assert(mat_mul_result != nullptr);
129 annot_data(mat_mul, std::move(mat_mul_result));
130 annot_domain(mat_mul, loco::Domain::Matrix);
133 } // namespace locomotiv