2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_ONEHOT_H__
19 #define __NNFW_CKER_ONEHOT_H__
21 #include "cker/Shape.h"
28 template <typename T, typename TI>
29 void OneHot(const int32_t depth, const T on_value, const T off_value, int32_t axis,
30 const Shape &indices_shape, const TI *indices_data, const Shape &, T *output_data)
33 axis = indices_shape.DimensionsCount();
35 // prefix_dim_size == # of elements before the axis
36 // depth == # of elements per axis
37 // suffix_dim_size == # of elements after the axis
38 int prefix_dim_size = 1;
39 for (int i = 0; i < axis; ++i)
41 prefix_dim_size *= indices_shape.Dims(i);
43 const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size;
45 // View the indices as a matrix of size:
46 // prefix_dim_size x suffix_dim_size
47 // View the output as a matrix of size:
48 // prefix_dim_size x depth x suffix_dim_size
49 // Then the output is:
50 // output(i, j, k) == (indices(i, k) == j) ? on : off
51 for (int i = 0; i < prefix_dim_size; ++i)
53 for (int j = 0; j < depth; ++j)
55 for (int k = 0; k < suffix_dim_size; ++k, ++output_data)
58 static_cast<int>(indices_data[i * suffix_dim_size + k]) == j ? on_value : off_value;
67 #endif // __NNFW_CKER_ONEHOT_H__