2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__
19 #define __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__
26 #include "cker/Shape.h"
33 // Simple wrapper over BCast specialized for MatMul.
34 // Provides utilities for broadcasting across batch dimensions for binary
35 // MatMul-like operations.
37 // Fix: Use Shape directly instead of Vec
41 MatMulBCast(Shape &shape_x, Shape &shape_y)
43 if (shape_x.DimensionsCount() < 2 || shape_y.DimensionsCount() < 2)
46 std::vector<int32_t> x;
47 std::vector<int32_t> y;
49 x.resize(shape_x.DimensionsCount() - 2);
50 y.resize(shape_y.DimensionsCount() - 2);
52 for (size_t i = 0; i < x.size(); i++)
54 x[i] = shape_x.Dims(i);
56 for (size_t i = 0; i < y.size(); i++)
58 y[i] = shape_y.Dims(i);
61 _batch_bcast = std::make_unique<BCast>(std::move(x), std::move(y));
62 if (!_batch_bcast->IsValid())
65 auto x_reshaped = _batch_bcast->x_reshape();
66 auto y_reshaped = _batch_bcast->y_reshape();
67 auto output_shape = _batch_bcast->output_shape();
69 _x_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
70 std::multiplies<int32_t>());
71 _y_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
72 std::multiplies<int32_t>());
73 _output_shape.ReplaceWith(output_shape.size(), output_shape.data());
74 _output_batch_size = _output_shape.FlatSize();
77 bool IsValid() const { return (_batch_bcast != nullptr) && _batch_bcast->IsValid(); }
78 int32_t x_batch_size() const { return _x_batch_size; }
79 int32_t y_batch_size() const { return _y_batch_size; }
80 int32_t output_batch_size() const { return _output_batch_size; }
81 const Shape &output_batch_shape() const { return _output_shape; }
84 std::unique_ptr<BCast> _batch_bcast;
86 int32_t _x_batch_size;
87 int32_t _y_batch_size;
89 int32_t _output_batch_size;
95 #endif // __NNFW_CKER_EINSUM_HELPER_MATMUL_BCAST_H__