8f0ec4ab6249f9da59cd01fb41d3a351d1c6e04a
[platform/core/ml/nnfw.git] / runtime / libs / misc / include / misc / tensor / Zipper.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file Zipper.h
19  * @ingroup COM_AI_RUNTIME
20  * @brief This file contains nnfw::misc::tensor::Zipper class
21  */
22
23 #ifndef __NNFW_MISC_TENSOR_ZIPPER_H__
24 #define __NNFW_MISC_TENSOR_ZIPPER_H__
25
26 #include "misc/tensor/Index.h"
27 #include "misc/tensor/IndexIterator.h"
28 #include "misc/tensor/Reader.h"
29
30 namespace nnfw
31 {
32 namespace misc
33 {
34 namespace tensor
35 {
36
37 /**
38  * @brief Class to apply a function with three params: @c Index, elements of a tensor
39  * at passed index read by @c Reader objects
40  */
41 template <typename T> class Zipper
42 {
43 public:
44   /**
45    * @brief Construct a new @c Zipper object
46    * @param[in] shape   Shape of @c lhs and @c rhs
47    * @param[in] lhs     @c Reader object of a tensor
48    * @param[in] rhs     @c Reader object of a tensor
49    */
50   Zipper(const Shape &shape, const Reader<T> &lhs, const Reader<T> &rhs)
51       : _shape{shape}, _lhs{lhs}, _rhs{rhs}
52   {
53     // DO NOTHING
54   }
55
56 public:
57   /**
58    * @brief Apply @c cb to all elements of tensors. Elements of two tensors
59    *        at passed @c index are read by @c lhs and @c rhs
60    * @param[in] cb   Function to apply
61    * @return    N/A
62    */
63   template <typename Callable> void zip(Callable cb) const
64   {
65     iterate(_shape) <<
66         [this, &cb](const Index &index) { cb(index, _lhs.at(index), _rhs.at(index)); };
67   }
68
69 private:
70   const Shape &_shape;
71   const Reader<T> &_lhs;
72   const Reader<T> &_rhs;
73 };
74
75 /**
76  * @brief Apply @c cb by using @c lhs and @c rhs passed to the constructor of @c zipper
77  * @param[in] zipper    @c Zipper object
78  * @param[in] cb        Function to zpply using @c zip function
79  * @return @c zipper object after applying @c cb to @c zipper
80  */
81 template <typename T, typename Callable>
82 const Zipper<T> &operator<<(const Zipper<T> &zipper, Callable cb)
83 {
84   zipper.zip(cb);
85   return zipper;
86 }
87
88 /**
89  * @brief Get @c Zipper object constructed using passed params
90  * @param shape   Shape of @c lhs and @c rhs
91  * @param lhs     @c Reader object of a tensor
92  * @param rhs     @c Reader object of a tensor
93  * @return        @c Zipper object
94  */
95 template <typename T> Zipper<T> zip(const Shape &shape, const Reader<T> &lhs, const Reader<T> &rhs)
96 {
97   return Zipper<T>{shape, lhs, rhs};
98 }
99
100 } // namespace tensor
101 } // namespace misc
102 } // namespace nnfw
103
104 #endif // __NNFW_MISC_TENSOR_ZIPPER_H__