2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ErrorMetric.h"
19 #include <loco/IR/DataType.h>
20 #include <loco/IR/DataTypeTraits.h>
25 using namespace mpqsolver::core;
 * @brief compare first and second operands in MAE (Mean Absolute Error metric)
30 float MAEMetric::compute(const WholeOutput &first, const WholeOutput &second) const
32 assert(first.size() == second.size());
35 size_t output_size = 0;
37 for (size_t sample_index = 0; sample_index < first.size(); ++sample_index)
39 assert(first[sample_index].size() == second[sample_index].size());
40 for (size_t out_index = 0; out_index < first[sample_index].size(); ++out_index)
42 const Buffer &first_elementary = first[sample_index][out_index];
43 const Buffer &second_elementary = second[sample_index][out_index];
44 assert(first_elementary.size() == second_elementary.size());
45 size_t cur_size = first_elementary.size() / loco::size(loco::DataType::FLOAT32);
47 const float *first_floats = reinterpret_cast<const float *>(first_elementary.data());
48 const float *second_floats = reinterpret_cast<const float *>(second_elementary.data());
49 for (size_t index = 0; index < cur_size; index++)
51 float ref_value = *(first_floats + index);
52 float cur_value = *(second_floats + index);
53 error += std::fabs(ref_value - cur_value);
55 output_size += cur_size;
61 throw std::runtime_error("nothing to compare");
64 return error / output_size;